Spaces:
Sleeping
Sleeping
File size: 2,533 Bytes
d258641 c514002 d258641 c514002 3fb9cf2 d258641 c514002 555e888 d258641 c514002 555e888 d258641 c514002 d258641 c514002 d258641 c514002 d258641 c514002 d258641 c514002 d258641 c514002 d258641 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
# ---- Device ----
# Use CUDA when torch reports it as available; otherwise stay on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# ---- Model IDs / paths ----
PRETRAINED_ID = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"  # crystal clear
FINETUNED_ID = "swin2sr_div2k_finetuned_x4_1000steps"  # smooth (local folder in repo)

# ---- Load processors ----
# The fine-tuned checkpoint is resolved with local_files_only=True.
processor_pre = AutoImageProcessor.from_pretrained(PRETRAINED_ID)
processor_ft = AutoImageProcessor.from_pretrained(
    FINETUNED_ID,
    local_files_only=True,
)

# ---- Load models ----
model_pre = Swin2SRForImageSuperResolution.from_pretrained(PRETRAINED_ID)
model_pre = model_pre.to(device)
model_ft = Swin2SRForImageSuperResolution.from_pretrained(
    FINETUNED_ID,
    local_files_only=True,
)
model_ft = model_ft.to(device)

# Both networks are only used for inference, so switch them to eval mode.
for _net in (model_pre, model_ft):
    _net.eval()
# ---- Inference function ----
def swin2sr_upscale(input_image: Image.Image, mode: str):
    """Run 4x super-resolution on an image using Swin2SR.

    Args:
        input_image: Image to upscale, or ``None`` when nothing was uploaded.
        mode: ``"Smooth (fine-tuned)"`` selects the fine-tuned model; any
            other value (including ``"Crystal clear (pretrained)"``) selects
            the pretrained model.

    Returns:
        The super-resolved image as an RGB ``PIL.Image.Image``, or ``None``
        when ``input_image`` is ``None``.
    """
    if input_image is None:
        return None

    if mode == "Smooth (fine-tuned)":
        model, processor = model_ft, processor_ft
    else:
        model, processor = model_pre, processor_pre

    # Uploads can be RGBA, palette or grayscale; normalise to 3-channel RGB
    # before preprocessing (assumes both checkpoints expect RGB input).
    rgb_image = input_image.convert("RGB")
    src_width, src_height = rgb_image.size

    inputs = processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis; a bare squeeze() would also remove a
    # height or width axis of size 1.
    sr_tensor = outputs.reconstruction.squeeze(0).clamp(0, 1)

    # The image processor may pad the input before inference, which would
    # leave a border in the reconstruction. Trim to source size x upscale;
    # the slice is a no-op when no padding took place.
    scale = getattr(model.config, "upscale", 4)
    sr_tensor = sr_tensor[:, : src_height * scale, : src_width * scale]

    # Round rather than truncate when quantising to 8 bits, then go CHW -> HWC.
    sr_array = sr_tensor.mul(255).round().byte().cpu().permute(1, 2, 0).numpy()
    return Image.fromarray(sr_array)
# ---- Gradio UI ----
with gr.Blocks() as demo:
    # Page heading followed by a short explanation of the two styles.
    gr.Markdown("# Image Super-Resolution (Swin2SR x4)")
    gr.Markdown(
        "Choose **Crystal clear (pretrained)** for the original Swin2SR model, "
        "or **Smooth (fine-tuned)** for the Swin2SR version we fine-tuned on DIV2K patches."
    )

    # NOTE(review): assumes only the two image panes belong to the row and the
    # remaining controls sit below it — confirm against the intended layout.
    with gr.Row():
        input_image = gr.Image(label="Upload low-res image", type="pil")
        output_image = gr.Image(label="4x Super-resolved image", type="pil")

    mode_dropdown = gr.Dropdown(
        choices=["Crystal clear (pretrained)", "Smooth (fine-tuned)"],
        value="Crystal clear (pretrained)",
        label="Style",
        interactive=True,
    )

    run_btn = gr.Button("Upscale")
    # On click, pass (image, style) to swin2sr_upscale and show its result.
    run_btn.click(swin2sr_upscale, [input_image, mode_dropdown], output_image)


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|