# irebmann's picture
# Update app.py
# 3fb9cf2 verified
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
# ---- Device ----
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---- Model IDs / paths ----
# Checkpoint behind the "crystal clear" mode.
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
# Checkpoint behind the "smooth" mode (local folder in repo).
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"

# ---- Load processors ----
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
# The fine-tuned checkpoint is loaded with local_files_only=True, matching
# its "local folder" identifier above.
processor_ft = AutoImageProcessor.from_pretrained(
    FINETUNED_ID,
    local_files_only=True,
)

# ---- Load models ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID)
model_pre = model_pre.to(device)
model_ft = Swin2SRForImageSuperResolution.from_pretrained(
    FINETUNED_ID,
    local_files_only=True,
)
model_ft = model_ft.to(device)

# Both networks are used for inference only, so put them in eval mode.
model_pre.eval()
model_ft.eval()
# ---- Inference function ----
def swin2sr_upscale(input_image: Image.Image, mode: str):
    """Run 4x super-resolution on an image with one of the loaded Swin2SR models.

    Args:
        input_image: Image to upscale, or ``None`` when nothing was uploaded.
        mode: ``"Smooth (fine-tuned)"`` selects the fine-tuned model; any
            other value (e.g. ``"Crystal clear (pretrained)"``) selects the
            pretrained model.

    Returns:
        The super-resolved ``PIL.Image.Image``, cropped to the source size
        times the model's upscale factor, or ``None`` when ``input_image``
        is ``None``.
    """
    if input_image is None:
        return None

    if mode == "Smooth (fine-tuned)":
        model, processor = model_ft, processor_ft
    else:
        model, processor = model_pre, processor_pre

    # Normalise palette / grayscale / alpha images to RGB before processing.
    # NOTE(review): assumes both checkpoints take 3-channel input - confirm
    # against their configs.
    if input_image.mode != "RGB":
        input_image = input_image.convert("RGB")
    src_width, src_height = input_image.size  # PIL reports (width, height)

    inputs = processor(images=input_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis; a bare squeeze() would also remove any other
    # axis that happens to have size 1.
    sr_tensor = outputs.reconstruction.squeeze(0).clamp(0, 1)

    # The Swin2SR image processor pads the input on the bottom/right up to a
    # multiple of its pad size, so the raw reconstruction includes an
    # upscaled padding border. Crop back to the true upscaled extent.
    # NOTE(review): falls back to 4 if the config exposes no `upscale`.
    scale = getattr(model.config, "upscale", 4)
    sr_tensor = sr_tensor[:, : src_height * scale, : src_width * scale]

    # Round before the uint8 cast: .byte() alone truncates, which biases
    # every pixel darker by up to one level.
    sr_array = sr_tensor.mul(255).round().byte().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray(sr_array)
# ---- Gradio UI ----
# NOTE(review): nesting below is reconstructed from the `with` statements;
# the dropdown and button are assumed to sit under the image row - confirm.
with gr.Blocks() as demo:
    # Page title and a short explanation of the two available styles.
    gr.Markdown(value="# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        value=(
            "Choose **Crystal clear (pretrained)** for the original Swin2SR model, "
            "or **Smooth (fine-tuned)** for the Swin2SR version we fine-tuned on DIV2K patches."
        )
    )

    # Upload and result are shown side by side.
    with gr.Row():
        input_image = gr.Image(label="Upload low-res image", type="pil")
        output_image = gr.Image(label="4x Super-resolved image", type="pil")

    # The selected label is passed verbatim as `mode` to swin2sr_upscale.
    mode_dropdown = gr.Dropdown(
        choices=["Crystal clear (pretrained)", "Smooth (fine-tuned)"],
        value="Crystal clear (pretrained)",
        label="Style",
        interactive=True,
    )
    run_btn = gr.Button(value="Upscale")

    # Clicking the button runs the upscaler on the uploaded image and
    # shows the result in the output component.
    run_btn.click(
        swin2sr_upscale,
        inputs=[input_image, mode_dropdown],
        outputs=output_image,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()