File size: 2,533 Bytes
d258641
 
 
 
 
c514002
 
d258641
c514002
 
3fb9cf2
d258641
c514002
 
555e888
d258641
c514002
 
555e888
d258641
c514002
 
d258641
c514002
 
d258641
 
c514002
d258641
 
 
 
c514002
 
 
 
 
 
 
 
d258641
 
 
 
 
 
 
 
 
 
 
c514002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d258641
c514002
 
 
 
 
d258641
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Model IDs / paths ----
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"          # crystal clear
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"                     # smooth (local folder in repo)

# ---- Load processors ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
processor_ft  = AutoImageProcessor.from_pretrained(FINETUNED_ID, local_files_only=True)

# ---- Load models ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID).to(device)
model_ft  = Swin2SRForImageSuperResolution.from_pretrained(FINETUNED_ID, local_files_only=True).to(device)

model_pre.eval()
model_ft.eval()

# ---- Inference function ----
def swin2sr_upscale(input_image: Image.Image, mode: str):
    """
    Run 4x super-resolution using Swin2SR.
    mode: "Crystal clear (pretrained)" or "Smooth (fine-tuned)".
    """
    if input_image is None:
        return None

    if mode == "Smooth (fine-tuned)":
        model = model_ft
        processor = processor_ft
    else:
        model = model_pre
        processor = processor_pre

    inputs = processor(images=input_image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    sr_tensor = outputs.reconstruction.squeeze().clamp(0, 1)
    sr_array = (sr_tensor.mul(255).byte().cpu().permute(1, 2, 0).numpy())
    sr_image = Image.fromarray(sr_array)

    return sr_image

# ---- Gradio UI ----
with gr.Blocks() as demo:
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        "Choose **Crystal clear (pretrained)** for the original Swin2SR model, "
        "or **Smooth (fine-tuned)** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload low-res image")
        output_image = gr.Image(type="pil", label="4x Super-resolved image")

    mode_dropdown = gr.Dropdown(
        label="Style",
        choices=["Crystal clear (pretrained)", "Smooth (fine-tuned)"],
        value="Crystal clear (pretrained)",
        interactive=True,
    )

    run_btn = gr.Button("Upscale")

    run_btn.click(
        fn=swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=output_image,
    )

if __name__ == "__main__":
    demo.launch()