"""Gradio demo for 4x image super-resolution with Swin2SR.

Two checkpoints are loaded once at import time and kept in memory:

* a pretrained checkpoint from the Hugging Face Hub ("crystal clear"), and
* a locally stored fine-tuned checkpoint ("smooth").

The user picks one of them in the UI and the uploaded image is upscaled
with the selected model.
"""
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Model IDs / paths ----
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"  # crystal clear
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"  # smooth (local folder in repo)

# ---- UI mode labels ----
# Shared by the dropdown and the inference function so the two cannot drift.
MODE_PRETRAINED = "Crystal clear (pretrained)"
MODE_FINETUNED = "Smooth (fine-tuned)"

# ---- Load processors ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
processor_ft = AutoImageProcessor.from_pretrained(FINETUNED_ID, local_files_only=True)

# ---- Load models ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID).to(device)
model_ft = Swin2SRForImageSuperResolution.from_pretrained(
    FINETUNED_ID, local_files_only=True
).to(device)

# Inference only: switch off dropout / other train-time behaviour.
model_pre.eval()
model_ft.eval()


# ---- Inference function ----
def swin2sr_upscale(input_image: Image.Image, mode: str):
    """Run super-resolution on ``input_image`` with the selected Swin2SR model.

    Args:
        input_image: Image uploaded through the UI, or ``None`` when the
            user clicked the button without uploading anything.
        mode: ``"Smooth (fine-tuned)"`` selects the fine-tuned checkpoint;
            any other value selects the pretrained checkpoint.

    Returns:
        The upscaled RGB ``PIL.Image.Image``, or ``None`` when no input
        image was supplied.
    """
    if input_image is None:
        return None

    if mode == MODE_FINETUNED:
        model, processor = model_ft, processor_ft
    else:
        model, processor = model_pre, processor_pre

    # Uploads may arrive as RGBA / L / P; normalise to 3-channel RGB so the
    # tensor layout below (C, H, W with C == 3) holds for every input.
    rgb_image = input_image.convert("RGB")
    src_width, src_height = rgb_image.size

    inputs = processor(images=rgb_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis; a bare squeeze() would also remove any other
    # size-1 dimension and break the permute below.
    sr_tensor = outputs.reconstruction.squeeze(0).clamp(0, 1)
    # Round before the uint8 cast; a plain cast truncates and darkens pixels.
    sr_array = sr_tensor.mul(255).round().byte().cpu().permute(1, 2, 0).numpy()
    sr_image = Image.fromarray(sr_array)

    # The image processor can pad the input (see its do_pad / pad_size
    # options), which leaves an upscaled padding border on the bottom/right.
    # Crop back to scale x the original size; this is a no-op without padding.
    scale = getattr(model.config, "upscale", 4)
    target_width = min(sr_image.width, src_width * scale)
    target_height = min(sr_image.height, src_height * scale)
    return sr_image.crop((0, 0, target_width, target_height))


# ---- Gradio UI ----
# NOTE(review): the original indentation was lost; the nesting below (images
# inside the Row, dropdown and button beneath it) is the assumed layout.
with gr.Blocks() as demo:
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        "Choose **Crystal clear (pretrained)** for the original Swin2SR model, "
        "or **Smooth (fine-tuned)** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload low-res image")
        output_image = gr.Image(type="pil", label="4x Super-resolved image")

    mode_dropdown = gr.Dropdown(
        label="Style",
        choices=[MODE_PRETRAINED, MODE_FINETUNED],
        value=MODE_PRETRAINED,
        interactive=True,
    )

    run_btn = gr.Button("Upscale")
    run_btn.click(
        fn=swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=output_image,
    )


if __name__ == "__main__":
    demo.launch()