Restore lightweight warp package for HF Space runtime
- .gitignore +14 -0
- warp/__init__.py +0 -0
- warp/gradio_app/__init__.py +1 -0
- warp/gradio_app/app-Nick.py +252 -0
- warp/gradio_app/app.py +359 -0
- warp/gradio_app/model_comparison.py +288 -0
- warp/gradio_app/models/__init__.py +5 -0
- warp/gradio_app/models/registry-Nick.py +9 -0
- warp/gradio_app/models/registry.py +5 -0
- warp/gradio_app/models/upscaler.py +211 -0
- warp/gradio_app/upscale_compare_tab.py +231 -0
- warp/inference/__init__.py +10 -0
- warp/inference/background_removal.py +226 -0
- warp/inference/hf_client.py +253 -0
- warp/inference/local_client.py +64 -0
- warp/inference/metrics.py +256 -0
- warp/inference/upscaler.py +117 -0
- warp/models/__init__.py +25 -0
- warp/models/registry.py +248 -0
.gitignore
ADDED
@@ -0,0 +1,14 @@
+
+# keep warp lightweight
+warp/inference_outputs/
+warp/data/
+warp/**/checkpoints/
+warp/**/*.pt
+warp/**/*.pth
+warp/**/*.ckpt
+warp/**/*.safetensors
+warp/**/*.bin
+warp/**/*.npz
+*.zip
+*.7z
+*.tar
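These patterns are what keep the Space checkout small: model weights, archives, and scraped data never reach git. A minimal sanity check, assuming `git` is on PATH and the repo root is the working directory (`is_ignored` and the sample paths are hypothetical, not part of the commit):

import subprocess

def is_ignored(path: str) -> bool:
    # `git check-ignore -q` exits 0 when the path is ignored, 1 otherwise.
    return subprocess.run(["git", "check-ignore", "-q", path]).returncode == 0

assert is_ignored("warp/models/checkpoints/net.pt")  # matches warp/**/checkpoints/ and warp/**/*.pt
assert not is_ignored("warp/gradio_app/app.py")      # source files stay tracked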
warp/__init__.py
ADDED
File without changes
warp/gradio_app/__init__.py
ADDED
@@ -0,0 +1 @@
+"""Gradio application for A360 WARP experimentation UI."""
warp/gradio_app/app-Nick.py
ADDED
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import gradio as gr
+from dotenv import load_dotenv
+
+if TYPE_CHECKING:
+    from supabase import Client as ClientType
+
+    from warp.data import ImageLoader as ImageLoaderType
+    from warp.gradio_app.models.upscaler import ImageUpscaler
+else:
+    ClientType = object
+    ImageLoaderType = object
+    ImageUpscaler = object
+
+try:
+    from supabase import Client, create_client
+except Exception:
+    create_client = None  # type: ignore
+    Client = None  # type: ignore
+
+try:
+    from warp.data import ImageLoader
+except ImportError:
+    ImageLoader = None  # type: ignore
+
+try:
+    from warp.gradio_app.models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+except ImportError:
+    create_upscaler = None  # type: ignore
+    UPSCALER_AVAILABLE = False
+
+load_dotenv()
+SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
+SUPABASE_ANON_KEY: str = os.getenv("SUPABASE_ANON_KEY", "")
+
+supabase: ClientType | None = None
+if callable(create_client) and SUPABASE_URL and SUPABASE_ANON_KEY:
+    supabase = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
+
+# Initialize image loader
+image_loader: ImageLoaderType | None = None
+try:
+    if callable(ImageLoader):
+        image_loader = ImageLoader()
+        print(f"✓ Loaded {len(image_loader.practices)} practices with scraped images")
+except Exception as e:
+    print(f"Warning: Could not initialize ImageLoader: {e}")
+
+# Initialize upscaler (lazy load)
+upscaler: ImageUpscaler | None = None
+
+
+def load_practice_images(practice_name: str) -> tuple[list, str]:
+    """Load sample images from a practice.
+
+    Args:
+        practice_name: Name of the practice to load images from
+
+    Returns:
+        Tuple of (list of image paths, status message)
+    """
+    if not image_loader:
+        return [], "Image loader not available"
+
+    if not practice_name:
+        return [], "Please select a practice"
+
+    try:
+        # Get random sample of images
+        image_paths = image_loader.get_random_images(practice_name, n=10)
+        stats = image_loader.get_practice_stats(practice_name)
+        msg = (
+            f"Loaded {len(image_paths)} sample images from {practice_name} "
+            f"(Total: {stats['total_images']} images)"
+        )
+        return [str(p) for p in image_paths], msg
+    except Exception as e:
+        return [], f"Error loading images: {e}"
+
+
+def run_model(procedure: str | None, notes: str | None) -> str:
+    """Run a placeholder model execution.
+
+    Args:
+        procedure: The selected procedure type
+        notes: Additional context or parameters
+
+    Returns:
+        A formatted string with procedure and notes information
+    """
+    return f"Procedure={procedure or 'n/a'} | Notes={notes or 'n/a'}"
+
+
+def upscale_images(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float
+) -> tuple[object, object, str]:
+    """Upscale before/after image pair.
+
+    Args:
+        before_img: Before image from Gradio
+        after_img: After image from Gradio
+        prompt: Quality prompt for upscaling
+        num_steps: Number of inference steps
+        guidance: Guidance scale
+
+    Returns:
+        Tuple of (upscaled_before, upscaled_after, status_message)
+    """
+    global upscaler
+
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+
+    if before_img is None or after_img is None:
+        return None, None, "Please upload both before and after images"
+
+    try:
+        # Lazy load upscaler
+        if upscaler is None and callable(create_upscaler):
+            upscaler = create_upscaler(model_type="sd-x4")
+
+        # Import PIL here to handle the images
+        from PIL import Image
+
+        # Convert Gradio images to PIL if needed
+        if not isinstance(before_img, Image.Image):
+            before_img = Image.fromarray(before_img)
+        if not isinstance(after_img, Image.Image):
+            after_img = Image.fromarray(after_img)
+
+        # Upscale the pair
+        before_upscaled, after_upscaled = upscaler.upscale_pair(
+            before_img,
+            after_img,
+            prompt=prompt,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance,
+        )
+
+        return (
+            before_upscaled,
+            after_upscaled,
+            f"✓ Successfully upscaled images 4x (Original: {before_img.size} → Upscaled: {before_upscaled.size})",
+        )
+
+    except Exception as e:
+        return None, None, f"Error during upscaling: {str(e)}"
+
+
+# Build the UI
+with gr.Blocks(title="A360 WARP — Gradio") as demo:
+    gr.Markdown("# A360 WARP — Experimentation UI (MVP)")
+    gr.Markdown("Load and experiment with before/after images from scraped medical practices.")
+
+    # Practice selection and image loading
+    with gr.Tab("Image Browser"):
+        with gr.Row():
+            practice_dropdown = gr.Dropdown(
+                label="Select Practice",
+                choices=image_loader.practices if image_loader else [],
+                value=None,
+            )
+            load_btn = gr.Button("Load Sample Images", variant="primary")
+
+        status_text = gr.Textbox(label="Status", interactive=False)
+        image_gallery = gr.Gallery(label="Sample Images", show_label=True, columns=5, height="auto")
+
+        load_btn.click(
+            fn=load_practice_images,
+            inputs=[practice_dropdown],
+            outputs=[image_gallery, status_text],
+        )
+
+    # Image Enhancement (Upscaling)
+    with gr.Tab("Image Enhancement"):
+        gr.Markdown(
+            "### Upscale Before/After Images\n"
+            "Upload medical before/after photos to upscale them 4x using AI. "
+            "This improves image quality and detail for better comparison."
+        )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("#### Original Images")
+                before_input = gr.Image(label="Before Image", type="numpy")
+                after_input = gr.Image(label="After Image", type="numpy")
+
+            with gr.Column():
+                gr.Markdown("#### Upscaled Images (4x)")
+                before_output = gr.Image(label="Upscaled Before")
+                after_output = gr.Image(label="Upscaled After")
+
+        with gr.Row():
+            with gr.Column():
+                prompt_input = gr.Textbox(
+                    label="Quality Prompt",
+                    value="high quality medical photography, sharp details, professional lighting",
+                    placeholder="Describe desired image quality...",
+                )
+            with gr.Column():
+                num_steps = gr.Slider(
+                    minimum=20,
+                    maximum=100,
+                    value=50,
+                    step=5,
+                    label="Inference Steps (higher = better quality, slower)",
+                )
+                guidance_scale = gr.Slider(
+                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
+                )
+
+        upscale_btn = gr.Button("Upscale Images", variant="primary", size="lg")
+        upscale_status = gr.Textbox(label="Status", interactive=False)
+
+        upscale_btn.click(
+            fn=upscale_images,
+            inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
+            outputs=[before_output, after_output, upscale_status],
+        )
+
+    # Model experimentation
+    with gr.Tab("Model Experiments"):
+        with gr.Row():
+            procedure = gr.Dropdown(
+                label="Procedure",
+                choices=[
+                    "breast-augmentation",
+                    "liposuction",
+                    "rhinoplasty",
+                    "ftm-top-surgery",
+                    "coolsculpting",
+                ],
+                value=None,
+            )
+            notes = gr.Textbox(label="Notes", placeholder="Run context / params…")
+        run = gr.Button("Run")
+        out = gr.Textbox(label="Output")
+
+        run.click(run_model, inputs=[procedure, notes], outputs=out)
+
+if __name__ == "__main__":
+    demo.launch()
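Because the filename contains a hyphen, `app-Nick.py` cannot be imported with a plain `import` statement. A sketch of loading it for ad-hoc testing via `importlib` (the alias `app_nick` is illustrative, and executing the module requires `gradio` since import-time code builds the Blocks UI):

import importlib.util

spec = importlib.util.spec_from_file_location("app_nick", "warp/gradio_app/app-Nick.py")
app_nick = importlib.util.module_from_spec(spec)
spec.loader.exec_module(app_nick)  # runs the module, constructing `demo`

# The handler degrades gracefully: with no uploads it returns a status tuple
# instead of raising (either the missing-dependency message or the upload
# prompt, depending on which optional packages are installed).
print(app_nick.upscale_images(None, None, "", 20, 7.5))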
warp/gradio_app/app.py
ADDED
@@ -0,0 +1,359 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import gradio as gr
+from dotenv import load_dotenv
+
+if TYPE_CHECKING:
+    from supabase import Client as ClientType
+
+    from warp.data import ImageLoader as ImageLoaderType
+    from warp.gradio_app.models.upscaler import ImageUpscaler
+else:
+    ClientType = object
+    ImageLoaderType = object
+    ImageUpscaler = object
+
+try:
+    from supabase import Client, create_client
+except Exception:
+    create_client = None  # type: ignore
+    Client = None  # type: ignore
+
+try:
+    from warp.data import ImageLoader
+except ImportError:
+    ImageLoader = None  # type: ignore
+
+try:
+    from warp.gradio_app.models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+    print("✓ Upscaler module loaded successfully")
+except ImportError as e:
+    create_upscaler = None  # type: ignore
+    UPSCALER_AVAILABLE = False
+    print(f"✗ Upscaler import failed: {e}")
+
+# Temporarily disable Advanced Upscaling tab due to import issues.
+# Will re-enable after fixing module resolution.
+COMPARE_TAB_AVAILABLE = False
+build_upscale_compare = None
+
+load_dotenv()
+SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
+SUPABASE_ANON_KEY: str = os.getenv("SUPABASE_ANON_KEY", "")
+
+supabase: ClientType | None = None
+if callable(create_client) and SUPABASE_URL and SUPABASE_ANON_KEY:
+    supabase = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
+
+# Initialize image loader
+image_loader: ImageLoaderType | None = None
+try:
+    if callable(ImageLoader):
+        image_loader = ImageLoader()
+        print(f"✓ Loaded {len(image_loader.practices)} practices with scraped images")
+except Exception as e:
+    # If initialization fails for any reason (including mocked import errors),
+    # fall back to no image loader so the rest of the app can still import.
+    image_loader = None
+    print(f"Warning: Could not initialize ImageLoader: {e}")
+
+# Initialize upscaler (lazy load)
+upscaler: ImageUpscaler | None = None
+
+
+def load_practice_images(practice_name: str) -> tuple[list, str]:
+    """Load sample images from a practice.
+
+    Args:
+        practice_name: Name of the practice to load images from
+
+    Returns:
+        Tuple of (list of image paths, status message)
+    """
+    if not image_loader:
+        return [], "Image loader not available"
+
+    if not practice_name:
+        return [], "Please select a practice"
+
+    try:
+        # Get random sample of images
+        image_paths = image_loader.get_random_images(practice_name, n=10)
+        stats = image_loader.get_practice_stats(practice_name)
+        msg = (
+            f"Loaded {len(image_paths)} sample images from {practice_name} "
+            f"(Total: {stats['total_images']} images)"
+        )
+        return [str(p) for p in image_paths], msg
+    except Exception as e:
+        return [], f"Error loading images: {e}"
+
+
+def run_model(procedure: str | None, notes: str | None) -> str:
+    """Run a placeholder model execution.
+
+    Args:
+        procedure: The selected procedure type
+        notes: Additional context or parameters
+
+    Returns:
+        A formatted string with procedure and notes information
+    """
+    return f"Procedure={procedure or 'n/a'} | Notes={notes or 'n/a'}"
+
+
+def upscale_images(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
+):
+    """Upscale before/after image pair (synchronous helper).
+
+    This version is a simple function (not a generator) so tests can call it
+    and make assertions about the returned tuple. The Gradio UI wraps this in
+    a streaming function that updates the progress bar and status text.
+
+    Returns:
+        Tuple of (upscaled_before, upscaled_after, status_message)
+    """
+    global upscaler
+
+    # Handle missing upscaler dependency
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+
+    # Validate inputs
+    if before_img is None or after_img is None:
+        return None, None, "Please upload both before and after images"
+
+    try:
+        # Lazy load upscaler on first use
+        if upscaler is None and callable(create_upscaler):
+            upscaler = create_upscaler(model_type="sd-x4")
+
+        # Import PIL here to handle the images
+        from PIL import Image
+
+        # Convert numpy arrays to PIL Images if needed
+        if not isinstance(before_img, Image.Image):
+            before_img = Image.fromarray(before_img)
+        if not isinstance(after_img, Image.Image):
+            after_img = Image.fromarray(after_img)
+
+        orig_size = before_img.size
+
+        # Use the pair upscaling helper with a callback that updates the
+        # Gradio progress bar more granularly during diffusion steps.
+        callback_state = {"phase": "before", "last_step": -1}
+
+        def progress_callback(step, timestep, latents):  # type: ignore[unused-argument]
+            """Update progress bar for each diffusion step.
+
+            We see steps 0..num_steps-1 for the "before" image first, then
+            again for the "after" image. When the step counter resets, we
+            switch to the "after" phase and map progress into [0.5, 0.9].
+            """
+            try:
+                # Detect phase change when step counter resets
+                if step < callback_state["last_step"]:
+                    callback_state["phase"] = "after"
+                callback_state["last_step"] = step
+
+                frac = step / max(num_steps, 1)
+                if callback_state["phase"] == "before":
+                    # Map to [0.1, 0.5]
+                    pct = 0.1 + 0.4 * frac
+                    desc = f"Upscaling BEFORE image ({step}/{num_steps})"
+                else:
+                    # Map to [0.5, 0.9]
+                    pct = 0.5 + 0.4 * frac
+                    desc = f"Upscaling AFTER image ({step}/{num_steps})"
+
+                try:
+                    progress(pct, desc=desc)
+                except Exception:
+                    # In tests or non-Gradio contexts, progress may be a no-op
+                    pass
+            except Exception:
+                # Never allow progress UI issues to break the core upscaling
+                pass
+
+        before_upscaled, after_upscaled = upscaler.upscale_pair(
+            before_img,
+            after_img,
+            prompt=prompt,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance,
+            callback=progress_callback,
+            callback_steps=1,
+        )
+
+        status = (
+            "Successfully upscaled both images 4x\n"
+            f"Original: {orig_size[0]}×{orig_size[1]} → "
+            f"Upscaled: {before_upscaled.size[0]}×{before_upscaled.size[1]}"
+        )
+        return before_upscaled, after_upscaled, status
+
+    except Exception as e:
+        # Graceful error handling for tests and UI
+        return None, None, f"Error during upscaling: {str(e)}"
+
+
+def upscale_images_stream(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
+):
+    """Streaming wrapper for ``upscale_images`` used by the Gradio UI.
+
+    Yields intermediate status updates so the user sees a live progress bar
+    and status text while the heavy model runs.
+    """
+    # Handle missing upscaler dependency
+    if not UPSCALER_AVAILABLE:
+        yield (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+        return
+
+    # Validate inputs
+    if before_img is None or after_img is None:
+        yield None, None, "Please upload both before and after images"
+        return
+
+    try:
+        # Initial progress
+        try:
+            progress(0.0, desc="Initializing upscaler...")
+        except Exception:
+            pass
+        yield None, None, "Initializing upscaler..."
+
+        # Coarse progress while running the model
+        try:
+            progress(0.3, desc="Upscaling images...")
+        except Exception:
+            pass
+
+        before_upscaled, after_upscaled, status = upscale_images(
+            before_img, after_img, prompt, num_steps, guidance, progress
+        )
+
+        try:
+            progress(1.0, desc="Complete")
+        except Exception:
+            pass
+
+        yield before_upscaled, after_upscaled, status
+
+    except Exception as e:
+        yield None, None, f"Error during upscaling: {str(e)}"
+
+
+# Build the UI
+with gr.Blocks(title="A360 WARP — Gradio") as demo:
+    gr.Markdown("# A360 WARP — Experimentation UI (MVP)")
+    gr.Markdown("Load and experiment with before/after images from scraped medical practices.")
+
+    # Practice selection and image loading
+    with gr.Tab("Image Browser"):
+        with gr.Row():
+            practice_dropdown = gr.Dropdown(
+                label="Select Practice",
+                choices=image_loader.practices if image_loader else [],
+                value=None,
+            )
+            load_btn = gr.Button("Load Sample Images", variant="primary")
+
+        status_text = gr.Textbox(label="Status", interactive=False)
+        image_gallery = gr.Gallery(label="Sample Images", show_label=True, columns=5, height="auto")
+
+        load_btn.click(
+            fn=load_practice_images,
+            inputs=[practice_dropdown],
+            outputs=[image_gallery, status_text],
+        )
+
+    # Image Enhancement (Upscaling)
+    with gr.Tab("Image Enhancement"):
+        gr.Markdown(
+            "### Upscale Before/After Images\n"
+            "Upload medical before/after photos to upscale them 4x using AI. "
+            "This improves image quality and detail for better comparison."
+        )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("#### Original Images")
+                before_input = gr.Image(label="Before Image", type="numpy")
+                after_input = gr.Image(label="After Image", type="numpy")
+
+            with gr.Column():
+                gr.Markdown("#### Upscaled Images (4x)")
+                before_output = gr.Image(label="Upscaled Before")
+                after_output = gr.Image(label="Upscaled After")
+
+        with gr.Row():
+            with gr.Column():
+                prompt_input = gr.Textbox(
+                    label="Quality Prompt",
+                    value="high quality medical photography, sharp details, professional lighting",
+                    placeholder="Describe desired image quality...",
+                )
+            with gr.Column():
+                num_steps = gr.Slider(
+                    minimum=20,
+                    maximum=100,
+                    value=50,
+                    step=5,
+                    label="Inference Steps (higher = better quality, slower)",
+                )
+                guidance_scale = gr.Slider(
+                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
+                )
+
+        upscale_btn = gr.Button("Upscale Images", variant="primary", size="lg")
+        upscale_status = gr.Textbox(label="Status", interactive=False)
+
+        # Use the streaming wrapper so users see live progress/status updates
+        upscale_btn.click(
+            fn=upscale_images_stream,
+            inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
+            outputs=[before_output, after_output, upscale_status],
+        )
+
+    # Advanced Upscaling with Comparison
+    if COMPARE_TAB_AVAILABLE and build_upscale_compare:
+        with gr.Tab("Advanced Upscaling"):
+            build_upscale_compare()
+
+    # Model experimentation
+    with gr.Tab("Model Experiments"):
+        with gr.Row():
+            procedure = gr.Dropdown(
+                label="Procedure",
+                choices=[
+                    "breast-augmentation",
+                    "liposuction",
+                    "rhinoplasty",
+                    "ftm-top-surgery",
+                    "coolsculpting",
+                ],
+                value=None,
+            )
+            notes = gr.Textbox(label="Notes", placeholder="Run context / params…")
+        run = gr.Button("Run")
+        out = gr.Textbox(label="Output")
+
+        run.click(run_model, inputs=[procedure, notes], outputs=out)
+
+if __name__ == "__main__":
+    demo.launch()
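The per-step callback above spreads two diffusion runs across one progress bar. A standalone sketch of the same arithmetic (`step_to_progress` is a hypothetical helper mirroring `progress_callback`): the before pass covers [0.1, 0.5] and the after pass covers [0.5, 0.9], so the bar never jumps backwards between images.

def step_to_progress(step: int, num_steps: int, phase: str) -> float:
    # Same mapping as progress_callback: normalize the step, then place the
    # fraction inside the band reserved for the current phase.
    frac = step / max(num_steps, 1)
    base = 0.1 if phase == "before" else 0.5
    return base + 0.4 * frac

assert abs(step_to_progress(0, 50, "before") - 0.1) < 1e-9
assert abs(step_to_progress(25, 50, "after") - 0.7) < 1e-9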
warp/gradio_app/model_comparison.py
ADDED
@@ -0,0 +1,288 @@
+"""
+Model Comparison App - Gradio interface for testing and comparing background removal models.
+"""
+
+import os
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+import pandas as pd
+from dotenv import load_dotenv
+from PIL import Image
+
+# Load environment variables from .env file
+env_path = Path(__file__).parent.parent.parent / ".env"
+load_dotenv(env_path)
+
+from warp.data import ImageLoader
+from warp.inference.background_removal import BackgroundRemovalEngine, BackgroundRemovalResult
+from warp.models import list_model_names
+
+# Initialize components
+engine = BackgroundRemovalEngine()
+loader = ImageLoader()
+
+
+def load_test_image(practice: str) -> Optional[Image.Image]:
+    """Load a random test image from a practice."""
+    try:
+        images = loader.get_random_images(practice, n=1)
+        if images:
+            return loader.load_image(images[0])
+    except Exception as e:
+        print(f"Error loading image: {e}")
+    return None
+
+
+def run_comparison(
+    image: Image.Image, selected_models: List[str]
+) -> Tuple[List, pd.DataFrame, str]:
+    """
+    Run background removal comparison across selected models.
+
+    Args:
+        image: Input PIL Image
+        selected_models: List of model names to test
+
+    Returns:
+        Tuple of (gallery_items, metrics_df, summary_text)
+    """
+    if not image:
+        return [], pd.DataFrame(), "❌ No image provided"
+
+    if not selected_models:
+        return [], pd.DataFrame(), "❌ No models selected"
+
+    print(f"\n{'='*60}")
+    print(f"Running comparison with {len(selected_models)} models...")
+    print(f"{'='*60}")
+
+    # Run comparisons
+    results = {}
+    for model_name in selected_models:
+        result = engine.remove_background(image, model_name)
+        results[model_name] = result
+
+    # Prepare outputs for gallery (list of tuples with image and caption)
+    output_images = [
+        (res.output_image, f"{name}\n{res.processing_time_ms}ms")
+        for name, res in results.items()
+        if res.success
+    ]
+
+    # Create metrics DataFrame
+    metrics_data = []
+    for name, result in results.items():
+        metrics_data.append(
+            {
+                "Model": name,
+                "Status": "✓ Success" if result.success else "✗ Failed",
+                "Time (ms)": result.processing_time_ms,
+                "Edge Quality": f"{result.edge_quality:.3f}" if result.edge_quality else "-",
+                "SSIM": f"{result.ssim:.3f}" if result.ssim else "-",
+                "PSNR (dB)": f"{result.psnr:.1f}" if result.psnr else "-",
+                "Transparency": f"{result.transparency_coverage:.1%}" if result.transparency_coverage else "-",
+                "Quality Score": f"{result.weighted_quality_score:.3f}" if result.weighted_quality_score else "-",
+                "Error": result.error_message or "-",
+            }
+        )
+
+    metrics_df = pd.DataFrame(metrics_data)
+
+    # Create summary
+    successful = sum(1 for r in results.values() if r.success)
+    avg_time = (
+        sum(r.processing_time_ms for r in results.values() if r.success) / successful
+        if successful > 0
+        else 0
+    )
+
+    summary = f"""
+## Comparison Summary
+
+- **Models Tested**: {len(selected_models)}
+- **Successful**: {successful}
+- **Failed**: {len(selected_models) - successful}
+- **Average Time**: {avg_time:.0f}ms
+""".strip()
+
+    return output_images, metrics_df, summary
+
+
+def run_single_model(
+    image: Image.Image, model_name: str
+) -> Tuple[Optional[Image.Image], str]:
+    """Run a single model and return output + info."""
+    if not image:
+        return None, "❌ No image provided"
+
+    result = engine.remove_background(image, model_name)
+
+    if result.success:
+        info = f"""
+### ✓ Success
+
+- **Model**: {result.model_name}
+- **Processing Time**: {result.processing_time_ms}ms
+- **Input Size**: {result.input_size[0]}x{result.input_size[1]}
+- **Output Size**: {result.output_size[0]}x{result.output_size[1]}
+""".strip()
+        return result.output_image, info
+    else:
+        info = f"""
+### ✗ Failed
+
+- **Model**: {result.model_name}
+- **Error**: {result.error_message}
+- **Time Elapsed**: {result.processing_time_ms}ms
+""".strip()
+        return image, info
+
+
+# ============================================================================
+# Gradio Interface
+# ============================================================================
+
+
+def create_comparison_tab():
+    """Create the model comparison tab."""
+    with gr.Column():
+        gr.Markdown("# 🔬 Model Comparison")
+        gr.Markdown("Compare multiple background removal models side-by-side")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(
+                    type="pil", label="Input Image", height=400
+                )
+
+                # Model selection
+                available_models = list_model_names("bg_removal")
+                model_selector = gr.CheckboxGroup(
+                    choices=available_models,
+                    value=[available_models[0]] if available_models else [],
+                    label="Select Models to Compare",
+                )
+
+                # Quick load options
+                with gr.Row():
+                    practice_dropdown = gr.Dropdown(
+                        choices=loader.practices if hasattr(loader, "practices") else [],
+                        label="Load Random Image From",
+                        value=None,
+                    )
+                    load_btn = gr.Button("📁 Load Sample", size="sm")
+
+                run_btn = gr.Button("▶️ Run Comparison", variant="primary", size="lg")
+
+            with gr.Column(scale=2):
+                summary_md = gr.Markdown("### Ready to compare models")
+                metrics_table = gr.DataFrame(label="Performance Metrics")
+
+        # Output gallery
+        gr.Markdown("### Model Outputs")
+        output_gallery = gr.Gallery(
+            label="Results", columns=3, height="auto", object_fit="contain"
+        )
+
+        # Wire up events
+        def load_sample(practice):
+            if practice:
+                img = load_test_image(practice)
+                return img
+            return None
+
+        load_btn.click(fn=load_sample, inputs=[practice_dropdown], outputs=[input_image])
+
+        run_btn.click(
+            fn=run_comparison,
+            inputs=[input_image, model_selector],
+            outputs=[output_gallery, metrics_table, summary_md],
+        )
+
+
+def create_single_model_tab():
+    """Create the single model testing tab."""
+    with gr.Column():
+        gr.Markdown("# 🎯 Single Model Test")
+        gr.Markdown("Test a single model with detailed results")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(type="pil", label="Input Image")
+
+                model_dropdown = gr.Dropdown(
+                    choices=list_model_names("bg_removal"),
+                    value=list_model_names("bg_removal")[0],
+                    label="Select Model",
+                )
+
+                run_btn = gr.Button("▶️ Remove Background", variant="primary")
+
+            with gr.Column(scale=1):
+                output_image = gr.Image(type="pil", label="Output Image")
+                result_info = gr.Markdown("### Waiting for input...")
+
+        run_btn.click(
+            fn=run_single_model,
+            inputs=[input_image, model_dropdown],
+            outputs=[output_image, result_info],
+        )
+
+
+def create_app():
+    """Create the full Gradio app."""
+    with gr.Blocks(title="WARP Model Comparison", theme=gr.themes.Soft()) as app:
+        gr.Markdown(
+            """
+# 🚀 WARP Model Test Harness
+
+**AI-Powered Image Processing Pipeline**
+
+Test and compare background removal models with real-time performance metrics.
+"""
+        )
+
+        with gr.Tabs():
+            with gr.Tab("Model Comparison"):
+                create_comparison_tab()
+
+            with gr.Tab("Single Model Test"):
+                create_single_model_tab()
+
+            with gr.Tab("Model Registry"):
+                gr.Markdown("## 📚 Available Models")
+
+                # Display model registry
+                from warp.models import get_models_by_type
+
+                bg_models = get_models_by_type("bg_removal")
+                model_info_data = []
+
+                for name, config in bg_models.items():
+                    model_info_data.append(
+                        {
+                            "Name": config.name,
+                            "Display Name": config.display_name,
+                            "Model ID": config.model_id,
+                            "Est. Time": f"{config.estimated_time_ms}ms",
+                            "Default": "✓" if config.is_default else "",
+                            "Description": config.description,
+                        }
+                    )
+
+                gr.DataFrame(value=pd.DataFrame(model_info_data), label="Background Removal Models")
+
+    return app
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+if __name__ == "__main__":
+    app = create_app()
+    app.launch(
+        server_name="127.0.0.1",  # Bind to localhost so the browser URL is valid on Windows
+        server_port=7860,
+        share=False,
+        show_error=True,
+    )
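For quick experiments outside the UI, `run_comparison` can be driven headlessly. A sketch, assuming the `warp` package is importable and the registry lists at least one `bg_removal` model (the synthetic test image is a stand-in for a real photo):

from PIL import Image

from warp.gradio_app.model_comparison import run_comparison
from warp.models import list_model_names

img = Image.new("RGB", (256, 256), color=(200, 180, 170))  # synthetic input
gallery, metrics_df, summary = run_comparison(img, list_model_names("bg_removal")[:1])
print(summary)
print(metrics_df[["Model", "Status", "Time (ms)"]])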
warp/gradio_app/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
+"""Model registry for A360 WARP."""
+
+from .registry import MODELS
+
+__all__ = ["MODELS"]
warp/gradio_app/models/registry-Nick.py
ADDED
@@ -0,0 +1,9 @@
+MODELS = {
+    "CLIP": "openai/clip-vit-base-patch32",
+    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
+    "DINOv2": "facebook/dinov2-base",
+}
+
+UPSCALER_MODELS = {
+    "SD-X4-Upscaler": "stabilityai/stable-diffusion-x4-upscaler",
+}
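The registry values are plain Hugging Face Hub IDs, so resolving an entry is one `from_pretrained` call away. A sketch for the CLIP entry (assumes `transformers` is installed; the first call downloads the checkpoint):

from transformers import CLIPModel, CLIPProcessor

model_id = MODELS["CLIP"]  # "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)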
warp/gradio_app/models/registry.py
ADDED
@@ -0,0 +1,5 @@
+MODELS = {
+    "CLIP": "openai/clip-vit-base-patch32",
+    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
+    "DINOv2": "facebook/dinov2-base",
+}
warp/gradio_app/models/upscaler.py
ADDED
@@ -0,0 +1,211 @@
+"""Image upscaling using HuggingFace models."""
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from PIL import Image
+
+if TYPE_CHECKING:
+    from diffusers import StableDiffusionUpscalePipeline as PipelineType
+else:
+    PipelineType = object
+
+try:
+    import torch
+    from diffusers import StableDiffusionUpscalePipeline
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    # If either torch or diffusers is missing, mark them unavailable. The
+    # tests patch these symbols as needed, and the runtime gracefully degrades
+    # by raising a clear ImportError from ImageUpscaler.__init__.
+    torch = None  # type: ignore[assignment]
+    StableDiffusionUpscalePipeline = None  # type: ignore[assignment]
+    TORCH_AVAILABLE = False
+
+
+class ImageUpscaler:
+    """Handle image upscaling using HuggingFace models."""
+
+    def __init__(
+        self, model_id: str = "stabilityai/stable-diffusion-x4-upscaler", device: str | None = None
+    ):
+        """Initialize the upscaler.
+
+        Args:
+            model_id: HuggingFace model identifier
+            device: Device to run model on ('cuda', 'cpu', or None for auto)
+        """
+        if not TORCH_AVAILABLE:
+            raise ImportError(
+                "torch and diffusers are required for upscaling. "
+                "Install with: pip install torch diffusers transformers"
+            )
+
+        self.model_id = model_id
+
+        # Auto-detect device
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.pipeline: PipelineType | None = None
+        self._load_model()
+
+    def _load_model(self) -> None:
+        """Load the upscaling model."""
+        print(f"Loading upscaler model: {self.model_id} on {self.device}...")
+
+        # Determine torch dtype based on device
+        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+            self.model_id, torch_dtype=torch_dtype
+        )
+        self.pipeline = self.pipeline.to(self.device)
+
+        # Enable memory optimizations if on CUDA
+        if self.device == "cuda":
+            self.pipeline.enable_attention_slicing()
+
+        print(f"✓ Model loaded successfully on {self.device}")
+
+    def upscale(
+        self,
+        image: Image.Image | str | Path,
+        prompt: str = "high quality, detailed, sharp",
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        callback=None,
+        callback_steps: int = 1,
+    ) -> Image.Image:
+        """Upscale an image 4x.
+
+        Args:
+            image: PIL Image or path to image file
+            prompt: Text prompt to guide upscaling (helps with quality)
+            num_inference_steps: Number of denoising steps (higher = better quality, slower)
+            guidance_scale: How closely to follow the prompt (7.5 is a good default)
+            callback: Optional callback function(step, timestep, latents) called each step
+            callback_steps: How often to call the callback (default: every step)
+
+        Returns:
+            Upscaled PIL Image
+        """
+        # Load image if path is provided
+        if isinstance(image, (str, Path)):
+            image = Image.open(image).convert("RGB")
+
+        # Ensure RGB mode
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        # Run upscaling
+        if self.pipeline is None:
+            raise RuntimeError("Pipeline not initialized")
+        result = self.pipeline(
+            prompt=prompt,
+            image=image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            callback=callback,
+            callback_steps=callback_steps,
+        )
+        upscaled: Image.Image = result.images[0]
+
+        return upscaled
+
+    def upscale_pair(
+        self,
+        before_image: Image.Image | str | Path,
+        after_image: Image.Image | str | Path,
+        prompt: str = "high quality medical photography, sharp details, professional lighting",
+        **kwargs,
+    ) -> tuple[Image.Image, Image.Image]:
+        """Upscale a before/after image pair.
+
+        Args:
+            before_image: Before image (PIL Image or path)
+            after_image: After image (PIL Image or path)
+            prompt: Text prompt for upscaling quality
+            **kwargs: Additional arguments for upscale()
+
+        Returns:
+            Tuple of (upscaled_before, upscaled_after)
+        """
+        print("Upscaling before image...")
+        before_upscaled = self.upscale(before_image, prompt=prompt, **kwargs)
+
+        print("Upscaling after image...")
+        after_upscaled = self.upscale(after_image, prompt=prompt, **kwargs)
+
+        return before_upscaled, after_upscaled
+
+    def batch_upscale(
+        self,
+        images: list[Image.Image | str | Path],
+        prompt: str = "high quality, detailed, sharp",
+        **kwargs,
+    ) -> list[Image.Image]:
+        """Upscale multiple images.
+
+        Args:
+            images: List of PIL Images or paths
+            prompt: Text prompt for upscaling
+            **kwargs: Additional arguments for upscale()
+
+        Returns:
+            List of upscaled PIL Images
+        """
+        results = []
+        for i, img in enumerate(images, 1):
+            print(f"Upscaling image {i}/{len(images)}...")
+            upscaled = self.upscale(img, prompt=prompt, **kwargs)
+            results.append(upscaled)
+        return results
+
+
+def create_upscaler(
+    model_type: Literal["sd-x4", "fast"] = "sd-x4", device: str | None = None
+) -> ImageUpscaler:
+    """Factory function to create an upscaler.
+
+    Args:
+        model_type: Type of upscaler model
+            - "sd-x4": Stable Diffusion 4x upscaler (high quality, slower)
+            - "fast": Faster alternative (to be implemented)
+        device: Device to run on ('cuda', 'cpu', or None for auto)
+
+    Returns:
+        Initialized ImageUpscaler
+    """
+    model_map = {
+        "sd-x4": "stabilityai/stable-diffusion-x4-upscaler",
+        # Can add more models here later
+    }
+
+    model_id = model_map.get(model_type, model_map["sd-x4"])
+    return ImageUpscaler(model_id=model_id, device=device)
+
+
+# NOTE:
+# -----
+# When this module is imported as a submodule of ``warp.gradio_app.models``
+# (e.g. via ``from warp.gradio_app.models import upscaler``), Python normally
+# caches it as an attribute on the parent package. That caching can interfere
+# with tests that manipulate ``sys.modules`` to simulate import failures
+# (like removing ``torch``/``diffusers`` and re-importing this module).
+#
+# To ensure those tests can reliably exercise the fallback path, we avoid
+# permanently caching this submodule on the parent package by removing the
+# attribute if it exists. The module itself remains available via
+# ``sys.modules['warp.gradio_app.models.upscaler']``.
+try:  # Best-effort; never fail import because of this cleanup.
+    import sys as _sys
+
+    _parent_pkg = _sys.modules.get("warp.gradio_app.models")
+    if _parent_pkg is not None and hasattr(_parent_pkg, "upscaler"):
+        delattr(_parent_pkg, "upscaler")
+except Exception:
+    pass
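A usage sketch for the factory above, assuming torch and diffusers are installed; the sample paths are illustrative, not part of the commit:

from warp.gradio_app.models.upscaler import create_upscaler

upscaler = create_upscaler(model_type="sd-x4")  # auto-selects cuda/cpu
before_4x, after_4x = upscaler.upscale_pair(
    "samples/before.jpg",  # hypothetical input paths
    "samples/after.jpg",
    num_inference_steps=30,  # fewer steps for a quicker smoke test
)
before_4x.save("samples/before_4x.png")
after_4x.save("samples/after_4x.png")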
warp/gradio_app/upscale_compare_tab.py
ADDED
@@ -0,0 +1,231 @@
+"""Advanced upscaling tab with before/after comparison and detailed metrics."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import gradio as gr
+import numpy as np
+from PIL import Image, ImageDraw
+
+if TYPE_CHECKING:
+    from .models.upscaler import ImageUpscaler
+
+try:
+    from .models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+except (ImportError, ModuleNotFoundError):
+    create_upscaler = None  # type: ignore[assignment]
+    UPSCALER_AVAILABLE = False
+
+
+# Global upscaler instance (lazy load)
+_upscaler: ImageUpscaler | None = None
+
+# Configuration
+MAX_INPUT_WIDTH = 1024
+MAX_INPUT_HEIGHT = 1024
+UPSCALE_FACTOR = 4
+QUALITY_PROMPT = "ultra realistic, natural contrast, high clarity, clean skin texture, professional medical photography"
+
+
+def _get_upscaler() -> ImageUpscaler | None:
+    """Lazy load upscaler on first use."""
+    global _upscaler
+    if _upscaler is None and UPSCALER_AVAILABLE:
+        try:
+            _upscaler = create_upscaler(model_type="sd-x4")
+        except Exception as e:
+            raise RuntimeError(f"Failed to load upscaler: {e}") from e
+    return _upscaler
+
+
+def _validate_and_resize_image(
+    img: np.ndarray | Image.Image, max_w: int = MAX_INPUT_WIDTH, max_h: int = MAX_INPUT_HEIGHT
+) -> Image.Image:
+    """Convert and validate image, resize if needed."""
+    # Convert numpy to PIL if needed
+    if isinstance(img, np.ndarray):
+        img = Image.fromarray(img.astype("uint8"))
+    elif not isinstance(img, Image.Image):
+        raise ValueError(f"Invalid image type: {type(img)}")
+
+    # Convert to RGB
+    if img.mode != "RGB":
+        img = img.convert("RGB")
+
+    # Check and resize if too large
+    w, h = img.size
+    if w > max_w or h > max_h:
+        img.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)
+
+    return img
+
+
+def _create_comparison_grid(
+    before_orig: Image.Image, after_orig: Image.Image, before_up: Image.Image, after_up: Image.Image
+) -> Image.Image:
+    """Create a 2x2 grid showing before/after, original/upscaled."""
+
+    # All upscaled images should be the same size (4x original)
+    # For display, we'll scale them down to fit alongside originals
+
+    orig_w, orig_h = before_orig.size
+    up_display_w, up_display_h = before_up.size  # Should be 4x larger
+
+    # Create display versions of upscaled (scaled down slightly for display)
+    display_scale = 0.5  # Show upscaled at 2x (half of 4x)
+    display_w = int(up_display_w * display_scale)
+    display_h = int(up_display_h * display_scale)
+
+    before_up_display = before_up.resize((display_w, display_h), Image.Resampling.LANCZOS)
+    after_up_display = after_up.resize((display_w, display_h), Image.Resampling.LANCZOS)
+
+    # Create grid background
+    grid_w = display_w * 2 + 40  # padding
+    grid_h = display_h * 2 + 80  # padding + title space
+
+    grid = Image.new("RGB", (grid_w, grid_h), color=(30, 30, 30))
+    draw = ImageDraw.Draw(grid)
+
+    # Add labels (simple text, no font to avoid system dependencies)
+    label_y = 10
+    draw.text((10, label_y), "BEFORE (Orig → Upscaled 2x)", fill=(255, 150, 0))
+    draw.text((display_w + 20, label_y), "AFTER (Orig → Upscaled 2x)", fill=(255, 150, 0))
+
+    # Paste images
+    paste_y = 40
+    grid.paste(before_orig, (10, paste_y))
+    grid.paste(before_up_display, (10, paste_y + orig_h + 10))
+
+    grid.paste(after_orig, (display_w + 20, paste_y))
+    grid.paste(after_up_display, (display_w + 20, paste_y + orig_h + 10))
+
+    return grid
+
+
+def upscale_and_compare(
+    before_img, after_img, prompt: str = QUALITY_PROMPT, num_steps: int = 50, guidance: float = 7.5
+) -> tuple[Image.Image, Image.Image, Image.Image, str]:
+    """Upscale before/after pair with detailed comparison.
+
+    Args:
+        before_img: Before image (numpy or PIL)
+        after_img: After image (numpy or PIL)
+        prompt: Quality prompt for upscaling
+        num_steps: Inference steps (20-100)
+        guidance: Guidance scale (1.0-15.0)
+
+    Returns:
+        Tuple of (before_upscaled, after_upscaled, comparison_grid, status_message)
+    """
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            None,
+            "❌ Upscaler not available. Install: pip install torch diffusers transformers",
+        )  # type: ignore[return-value]
+
+    if before_img is None or after_img is None:
+        return None, None, None, "❌ Please upload both before and after images"  # type: ignore[return-value]
+
+    try:
+        # Get upscaler instance
+        upscaler = _get_upscaler()
+        if upscaler is None:
+            return None, None, None, "❌ Upscaler not available"  # type: ignore[return-value]
+
+        # Validate and resize inputs
+        before_pil = _validate_and_resize_image(before_img)
+        after_pil = _validate_and_resize_image(after_img)
+
+        orig_before_size = before_pil.size
+        orig_after_size = after_pil.size
+
+        # Upscale both images
+        print(f"Upscaling before image ({orig_before_size})...")
+        before_upscaled = upscaler.upscale(
+            before_pil, prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance
+        )
+
+        print(f"Upscaling after image ({orig_after_size})...")
+        after_upscaled = upscaler.upscale(
+            after_pil, prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance
+        )
+
+        # Create comparison grid
+        comparison = _create_comparison_grid(before_pil, after_pil, before_upscaled, after_upscaled)
+
+        # Build status message
+        status = (
+            f"✅ Successfully upscaled both images!\n\n"
+            f"Before: {orig_before_size} → {before_upscaled.size}\n"
+            f"After: {orig_after_size} → {after_upscaled.size}\n"
+            f"Upscale Factor: {UPSCALE_FACTOR}x\n"
+            f"Steps: {num_steps} | Guidance: {guidance}\n"
+            f"\nNote: Comparison shows upscaled images at 2x (50% of 4x for display)"
+        )
+
+        return before_upscaled, after_upscaled, comparison, status
| 173 |
+
|
| 174 |
+
except Exception as e:
|
| 175 |
+
error_msg = f"❌ Error during upscaling: {str(e)}"
|
| 176 |
+
print(error_msg)
|
| 177 |
+
return None, None, None, error_msg # type: ignore[return-value]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def build_ui() -> None:
|
| 181 |
+
"""Build the advanced upscaling UI tab."""
|
| 182 |
+
|
| 183 |
+
gr.Markdown("### Upscale Before/After Images (Max Detail & Clarity)")
|
| 184 |
+
gr.Markdown(
|
| 185 |
+
"Upload medical before/after photos to upscale them 4x using Stable Diffusion x4 Upscaler. "
|
| 186 |
+
"Both images are processed with identical parameters for fair comparison.\n\n"
|
| 187 |
+
f"⚠️ **Note:** Processing takes 30-60 seconds per image (CPU) or 5-10 seconds (GPU). "
|
| 188 |
+
f"Maximum input size: {MAX_INPUT_WIDTH}x{MAX_INPUT_HEIGHT}px (automatically resized if larger)."
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
with gr.Row():
|
| 192 |
+
with gr.Column():
|
| 193 |
+
gr.Markdown("#### Original Images")
|
| 194 |
+
before_input = gr.Image(label="Before Image", type="numpy")
|
| 195 |
+
after_input = gr.Image(label="After Image", type="numpy")
|
| 196 |
+
|
| 197 |
+
# Parameters
|
| 198 |
+
prompt_input = gr.Textbox(
|
| 199 |
+
label="Quality Prompt",
|
| 200 |
+
value=QUALITY_PROMPT,
|
| 201 |
+
placeholder="Describe desired image quality...",
|
| 202 |
+
lines=3,
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
with gr.Column():
|
| 206 |
+
gr.Markdown("#### Upscaled Results (4x)")
|
| 207 |
+
before_output = gr.Image(label="Upscaled Before", type="pil")
|
| 208 |
+
after_output = gr.Image(label="Upscaled After", type="pil")
|
| 209 |
+
|
| 210 |
+
with gr.Row():
|
| 211 |
+
with gr.Column(scale=1):
|
| 212 |
+
num_steps = gr.Slider(
|
| 213 |
+
minimum=20, maximum=100, value=50, step=5, label="Inference Steps"
|
| 214 |
+
)
|
| 215 |
+
with gr.Column(scale=1):
|
| 216 |
+
guidance_scale = gr.Slider(
|
| 217 |
+
minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
upscale_btn = gr.Button("🚀 Upscale Both", variant="primary", size="lg")
|
| 221 |
+
upscale_status = gr.Textbox(label="Status", interactive=False, lines=4)
|
| 222 |
+
|
| 223 |
+
gr.Markdown("#### Side-by-Side Comparison")
|
| 224 |
+
comparison_output = gr.Image(label="Comparison Grid", type="pil")
|
| 225 |
+
|
| 226 |
+
# Button click handler
|
| 227 |
+
upscale_btn.click(
|
| 228 |
+
fn=upscale_and_compare,
|
| 229 |
+
inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
|
| 230 |
+
outputs=[before_output, after_output, comparison_output, upscale_status],
|
| 231 |
+
)
|
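A minimal sketch of how this tab could be mounted in a host Gradio app. The module path warp.gradio_app.upscale_compare_tab is inferred from the file list, and the gr.Blocks/gr.Tab wrapper here is an assumption for illustration, not the committed app.py wiring.

import gradio as gr

from warp.gradio_app import upscale_compare_tab

# Hypothetical host app: build_ui() only emits components, so it must be
# called inside an active Blocks context.
with gr.Blocks(title="WARP Upscale Comparison") as demo:
    with gr.Tab("Upscale Compare"):
        upscale_compare_tab.build_ui()

if __name__ == "__main__":
    demo.launch()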
warp/inference/__init__.py
ADDED
@@ -0,0 +1,10 @@
"""WARP Inference - AI model inference interfaces."""

from .hf_client import HuggingFaceAPIError, HuggingFaceClient, create_client, infer_image

__all__ = [
    "HuggingFaceClient",
    "HuggingFaceAPIError",
    "create_client",
    "infer_image",
]
warp/inference/background_removal.py
ADDED
@@ -0,0 +1,226 @@
"""
Background Removal - Unified interface for background removal models with performance tracking.
"""

import time
from dataclasses import dataclass
from typing import Dict, Optional

from PIL import Image

from warp.inference.hf_client import HuggingFaceClient
from warp.inference.local_client import LocalBackgroundRemovalClient
from warp.inference.metrics import calculate_comprehensive_metrics, calculate_weighted_quality_score
from warp.models import get_model, list_model_names


@dataclass
class BackgroundRemovalResult:
    """Result of a background removal operation."""

    output_image: Image.Image
    model_name: str
    processing_time_ms: int
    success: bool = True
    error_message: Optional[str] = None
    input_size: tuple = (0, 0)
    output_size: tuple = (0, 0)

    # Quality metrics
    edge_quality: Optional[float] = None
    ssim: Optional[float] = None
    psnr: Optional[float] = None
    transparency_coverage: Optional[float] = None
    mask_accuracy: Optional[float] = None
    weighted_quality_score: Optional[float] = None


class BackgroundRemovalEngine:
    """Engine for running background removal with multiple models."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the background removal engine.

        Args:
            api_key: Hugging Face API key (optional, defaults to env var)
        """
        self.hf_client = None
        self.local_clients = {}  # Cache local model sessions
        self.api_key = api_key
        self.available_models = list_model_names("bg_removal")

    def remove_background(
        self, image: Image.Image, model_name: Optional[str] = None
    ) -> BackgroundRemovalResult:
        """
        Remove the background from an image using the specified model.

        Args:
            image: Input PIL Image
            model_name: Model name from registry (uses default if None)

        Returns:
            BackgroundRemovalResult with output image and metrics

        Raises:
            ValueError: If the model is not found
        """
        # Get the model config
        if model_name is None:
            from warp.models import get_default_model

            model_config = get_default_model("bg_removal")
            if not model_config:
                raise ValueError("No default background removal model found")
        else:
            model_config = get_model(model_name)
            if not model_config:
                raise ValueError(f"Model '{model_name}' not found in registry")

        input_size = image.size
        start_time = time.time()

        try:
            # Route to the appropriate client based on provider
            if model_config.provider == "local":
                # Use local rembg
                if model_config.model_id not in self.local_clients:
                    self.local_clients[model_config.model_id] = LocalBackgroundRemovalClient(
                        model_name=model_config.model_id
                    )
                client = self.local_clients[model_config.model_id]
                output_image = client.remove_background(image)
            else:
                # Use the Hugging Face API
                if self.hf_client is None:
                    self.hf_client = HuggingFaceClient(api_key=self.api_key)
                output_image = self.hf_client.infer_image(
                    model_id=model_config.model_id,
                    image=image,
                    parameters=model_config.default_parameters,
                )

            processing_time_ms = int((time.time() - start_time) * 1000)

            # Calculate quality metrics
            metrics = calculate_comprehensive_metrics(
                output_image=output_image,
                input_image=image,
            )
            weighted_score = calculate_weighted_quality_score(metrics)

            return BackgroundRemovalResult(
                output_image=output_image,
                model_name=model_config.name,
                processing_time_ms=processing_time_ms,
                success=True,
                input_size=input_size,
                output_size=output_image.size,
                edge_quality=metrics.get('edge_quality'),
                ssim=metrics.get('ssim'),
                psnr=metrics.get('psnr'),
                transparency_coverage=metrics.get('transparency_coverage'),
                mask_accuracy=metrics.get('mask_accuracy'),
                weighted_quality_score=weighted_score,
            )

        except Exception as e:
            processing_time_ms = int((time.time() - start_time) * 1000)
            return BackgroundRemovalResult(
                output_image=image,  # Return the original on error
                model_name=model_config.name,
                processing_time_ms=processing_time_ms,
                success=False,
                error_message=str(e),
                input_size=input_size,
                output_size=image.size,
            )

    def compare_models(
        self, image: Image.Image, model_names: Optional[list] = None
    ) -> Dict[str, BackgroundRemovalResult]:
        """
        Compare multiple background removal models on the same image.

        Args:
            image: Input PIL Image
            model_names: List of model names to compare (uses all if None)

        Returns:
            Dictionary mapping model names to results
        """
        if model_names is None:
            model_names = self.available_models

        results = {}
        for model_name in model_names:
            print(f"Processing with {model_name}...")
            result = self.remove_background(image, model_name)
            results[model_name] = result

            if result.success:
                print(f"  ✓ {result.processing_time_ms}ms")
            else:
                print(f"  ✗ Error: {result.error_message}")

        return results


# ============================================================================
# Convenience Functions
# ============================================================================


def remove_background(
    image: Image.Image, model_name: Optional[str] = None, api_key: Optional[str] = None
) -> Image.Image:
    """
    Quick background removal function.

    Args:
        image: Input PIL Image
        model_name: Model name (uses default if None)
        api_key: Optional API key

    Returns:
        Output PIL Image with background removed
    """
    engine = BackgroundRemovalEngine(api_key=api_key)
    result = engine.remove_background(image, model_name)

    if not result.success:
        raise RuntimeError(f"Background removal failed: {result.error_message}")

    return result.output_image


if __name__ == "__main__":
    # Demo/test code
    import os

    print("=== Background Removal Engine ===\n")

    engine = BackgroundRemovalEngine()
    print(f"Available models: {', '.join(engine.available_models)}\n")

    # Test with a sample image if available
    test_image_path = "data/scrapedimages/drleedy.com"
    if os.path.exists(test_image_path):
        from warp.data import ImageLoader

        loader = ImageLoader()
        images = loader.get_random_images("drleedy.com", n=1)

        if images:
            print(f"Testing with image: {images[0]}")
            test_image = loader.load_image(images[0])

            # Test a single model
            result = engine.remove_background(test_image, model_name="rmbg-1.4")
            print("\nTest result:")
            print(f"  Model: {result.model_name}")
            print(f"  Time: {result.processing_time_ms}ms")
            print(f"  Success: {result.success}")
    else:
        print("No test images found. Run with actual images to test.")
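A minimal usage sketch of the engine's two entry points. The input path sample.png is a hypothetical placeholder; substitute any local RGB image.

from PIL import Image

from warp.inference.background_removal import BackgroundRemovalEngine

image = Image.open("sample.png")  # hypothetical input

engine = BackgroundRemovalEngine()

# Single model: run the registry default and inspect the tracked metrics.
result = engine.remove_background(image)
print(result.model_name, result.processing_time_ms, result.weighted_quality_score)

# Shoot-out: run every registered bg_removal model and rank by quality score.
results = engine.compare_models(image)
ranked = sorted(
    (r for r in results.values() if r.success),
    key=lambda r: r.weighted_quality_score or 0.0,
    reverse=True,
)
for r in ranked:
    print(f"{r.model_name}: score={r.weighted_quality_score:.3f} time={r.processing_time_ms}ms")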
warp/inference/hf_client.py
ADDED
@@ -0,0 +1,253 @@
"""
Hugging Face API Client - Unified interface for calling the Hugging Face Inference API
with retry logic, rate limiting, and error handling.
"""

import os
import time
from io import BytesIO
from typing import Any, Dict, Optional

import requests
from PIL import Image

# Default configuration
DEFAULT_API_URL = "https://api-inference.huggingface.co/models"
DEFAULT_TIMEOUT = 120  # seconds
DEFAULT_MAX_RETRIES = 3
DEFAULT_RETRY_DELAY = 2  # seconds


class HuggingFaceAPIError(Exception):
    """Custom exception for Hugging Face API errors."""

    pass


class HuggingFaceClient:
    """Client for the Hugging Face Inference API."""

    def __init__(
        self,
        api_key: Optional[str] = None,
        max_retries: int = DEFAULT_MAX_RETRIES,
        retry_delay: int = DEFAULT_RETRY_DELAY,
        timeout: int = DEFAULT_TIMEOUT,
    ):
        """
        Initialize the Hugging Face API client.

        Args:
            api_key: Hugging Face API key (defaults to HF_API_KEY env var)
            max_retries: Maximum number of retry attempts
            retry_delay: Delay between retries in seconds
            timeout: Request timeout in seconds
        """
        self.api_key = api_key or os.getenv("HF_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Hugging Face API key not provided. "
                "Set the HF_API_KEY environment variable or pass the api_key parameter."
            )

        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.timeout = timeout

        self.headers = {"Authorization": f"Bearer {self.api_key}"}

    def _build_url(self, model_id: str) -> str:
        """Build the full API URL for a model."""
        return f"{DEFAULT_API_URL}/{model_id}"

    def _handle_response(self, response: requests.Response) -> Any:
        """
        Handle an API response and raise appropriate errors.

        Args:
            response: Response from the API

        Returns:
            Response content (bytes or dict)

        Raises:
            HuggingFaceAPIError: If the API returns an error
        """
        if response.status_code == 200:
            # Check the content type
            content_type = response.headers.get("content-type", "")
            if "application/json" in content_type:
                return response.json()
            else:
                return response.content

        # Handle errors
        error_msg = f"API request failed with status {response.status_code}"
        try:
            error_detail = response.json()
            if "error" in error_detail:
                error_msg = f"{error_msg}: {error_detail['error']}"
        except Exception:
            error_msg = f"{error_msg}: {response.text[:200]}"

        raise HuggingFaceAPIError(error_msg)

    def infer(
        self,
        model_id: str,
        inputs: Any,
        parameters: Optional[Dict] = None,
        return_json: bool = False,
    ) -> Any:
        """
        Call the Hugging Face Inference API with retry logic.

        Args:
            model_id: Hugging Face model ID (e.g., 'briaai/RMBG-2.0')
            inputs: Input data (bytes, PIL Image, or dict)
            parameters: Additional parameters for the model
            return_json: Whether to expect a JSON response

        Returns:
            API response (bytes for images, dict for JSON)

        Raises:
            HuggingFaceAPIError: If all retries fail
        """
        url = self._build_url(model_id)
        payload = self._prepare_payload(inputs, parameters)

        for attempt in range(1, self.max_retries + 1):
            try:
                response = requests.post(
                    url, headers=self.headers, data=payload, timeout=self.timeout
                )

                # Handle model loading
                if response.status_code == 503:
                    error_data = response.json()
                    if "estimated_time" in error_data:
                        wait_time = error_data["estimated_time"]
                        print(f"Model loading, waiting {wait_time}s...")
                        time.sleep(wait_time + 1)
                        continue

                return self._handle_response(response)

            except requests.exceptions.RequestException as e:
                if attempt == self.max_retries:
                    raise HuggingFaceAPIError(f"Request failed after {attempt} attempts: {e}")

                print(f"Attempt {attempt} failed: {e}. Retrying in {self.retry_delay}s...")
                time.sleep(self.retry_delay)

        raise HuggingFaceAPIError(f"All {self.max_retries} retries exhausted")

    def _prepare_payload(self, inputs: Any, parameters: Optional[Dict] = None) -> bytes:
        """
        Prepare the payload for an API request.

        Args:
            inputs: Input data (bytes, PIL Image, or dict)
            parameters: Additional parameters

        Returns:
            Bytes payload
        """
        # Convert a PIL Image to bytes
        if isinstance(inputs, Image.Image):
            buffer = BytesIO()
            inputs.save(buffer, format="PNG")
            inputs = buffer.getvalue()

        # If already bytes, return as-is
        if isinstance(inputs, bytes):
            return inputs

        # For dict/JSON inputs, serialize
        if isinstance(inputs, dict):
            import json

            return json.dumps(inputs).encode("utf-8")

        raise ValueError(f"Unsupported input type: {type(inputs)}")

    def infer_image(self, model_id: str, image: Image.Image, parameters: Optional[Dict] = None) -> Image.Image:
        """
        Call the inference API with a PIL Image and return a PIL Image.

        Args:
            model_id: Hugging Face model ID
            image: Input PIL Image
            parameters: Additional parameters

        Returns:
            Output PIL Image

        Raises:
            HuggingFaceAPIError: If inference fails
        """
        response_bytes = self.infer(model_id, image, parameters)
        return Image.open(BytesIO(response_bytes))

    def health_check(self, model_id: str) -> bool:
        """
        Check whether a model is available and loaded.

        Args:
            model_id: Hugging Face model ID

        Returns:
            True if the model is accessible
        """
        try:
            url = self._build_url(model_id)
            response = requests.get(url, headers=self.headers, timeout=10)
            return response.status_code in [200, 503]  # 503 means loading
        except Exception:
            return False


# ============================================================================
# Convenience Functions
# ============================================================================


def create_client(api_key: Optional[str] = None) -> HuggingFaceClient:
    """Create a HuggingFace client with default settings."""
    return HuggingFaceClient(api_key=api_key)


def infer_image(
    model_id: str, image: Image.Image, api_key: Optional[str] = None, parameters: Optional[Dict] = None
) -> Image.Image:
    """
    One-shot image inference with a Hugging Face model.

    Args:
        model_id: Hugging Face model ID
        image: Input PIL Image
        api_key: Optional API key (defaults to env var)
        parameters: Additional parameters

    Returns:
        Output PIL Image
    """
    client = create_client(api_key)
    return client.infer_image(model_id, image, parameters)


if __name__ == "__main__":
    # Demo/test code
    client = create_client()

    # Test model availability
    test_model = "briaai/RMBG-1.4"
    print(f"Testing connection to {test_model}...")
    is_available = client.health_check(test_model)
    print(f"Model available: {is_available}")

    if is_available:
        print("\n✓ Hugging Face API client ready")
    else:
        print("\n✗ Failed to connect to model")
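A minimal sketch of one-shot inference through this client, assuming HF_API_KEY is set in the environment; input.png and output.png are hypothetical file names.

from PIL import Image

from warp.inference.hf_client import HuggingFaceAPIError, infer_image

image = Image.open("input.png")  # hypothetical input; any RGB image will do

try:
    # One-shot call: the client handles payload encoding, 503 warm-up waits,
    # and retries before decoding the response bytes back into a PIL Image.
    output = infer_image("briaai/RMBG-1.4", image)
    output.save("output.png")
except HuggingFaceAPIError as e:
    # Raised after all retries are exhausted or on a non-200 response.
    print(f"Inference failed: {e}")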
warp/inference/local_client.py
ADDED
@@ -0,0 +1,64 @@
"""
Local Background Removal Client - Uses rembg for local processing.
"""

from PIL import Image
from rembg import new_session, remove


class LocalBackgroundRemovalClient:
    """Client for local background removal using rembg."""

    # Available models in rembg
    MODELS = {
        "u2net": "General purpose model (fast)",
        "u2netp": "Lightweight version of u2net",
        "u2net_human_seg": "Optimized for human segmentation",
        "u2net_cloth_seg": "Optimized for cloth segmentation",
        "silueta": "General segmentation",
        "isnet-general-use": "Improved segmentation (recommended)",
        "isnet-anime": "Anime/illustration specific",
    }

    def __init__(self, model_name: str = "isnet-general-use"):
        """
        Initialize the local client with a specific model.

        Args:
            model_name: Name of the rembg model to use
        """
        self.model_name = model_name
        self.session = None

    def _get_session(self):
        """Lazy-load the model session."""
        if self.session is None:
            self.session = new_session(self.model_name)
        return self.session

    def remove_background(self, image: Image.Image) -> Image.Image:
        """
        Remove the background from an image.

        Args:
            image: Input PIL Image

        Returns:
            PIL Image with background removed (RGBA)
        """
        session = self._get_session()
        output = remove(image, session=session)
        return output

    @classmethod
    def list_models(cls):
        """List available models."""
        return cls.MODELS


if __name__ == "__main__":
    # Test
    print("Available rembg models:")
    for name, desc in LocalBackgroundRemovalClient.MODELS.items():
        print(f"  - {name}: {desc}")
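A minimal sketch of the local path, assuming rembg and its model weights are installed; portrait.jpg is a hypothetical input file.

from PIL import Image

from warp.inference.local_client import LocalBackgroundRemovalClient

# Pick a model from LocalBackgroundRemovalClient.MODELS; the onnx session
# is created lazily on the first remove_background() call.
client = LocalBackgroundRemovalClient(model_name="u2net_human_seg")

image = Image.open("portrait.jpg")  # hypothetical input
cutout = client.remove_background(image)
cutout.save("portrait_cutout.png")  # RGBA output with transparent background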
warp/inference/metrics.py
ADDED
@@ -0,0 +1,256 @@
"""
Quality Metrics - Calculate image quality metrics for background removal evaluation.
"""

from typing import Dict, Optional

import cv2
import numpy as np
from PIL import Image


def calculate_edge_quality(image: Image.Image, mask: Optional[Image.Image] = None) -> float:
    """
    Calculate an edge quality score based on gradient strength.

    Higher score = sharper edges.

    Args:
        image: PIL Image (RGBA for background removal)
        mask: Optional mask (if available separately)

    Returns:
        Edge quality score (0.0 to 1.0)
    """
    # Convert to a numpy array
    if image.mode == 'RGBA':
        # Use the alpha channel as the mask
        img_array = np.array(image)
        mask_array = img_array[:, :, 3]
    elif mask is not None:
        mask_array = np.array(mask.convert('L'))
    else:
        # Fall back to grayscale if there is no alpha channel
        mask_array = np.array(image.convert('L'))

    # Calculate gradients using Sobel filters
    grad_x = cv2.Sobel(mask_array, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(mask_array, cv2.CV_64F, 0, 1, ksize=3)

    # Gradient magnitude
    gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)

    # Normalize to 0-1
    if gradient_magnitude.max() > 0:
        edge_score = np.mean(gradient_magnitude) / gradient_magnitude.max()
    else:
        edge_score = 0.0

    return float(edge_score)


def calculate_mask_accuracy(pred_mask: Image.Image, gt_mask: Image.Image) -> float:
    """
    Calculate mask accuracy (IoU - Intersection over Union).

    Only applicable when a ground truth mask is available.

    Args:
        pred_mask: Predicted mask (alpha channel or grayscale)
        gt_mask: Ground truth mask

    Returns:
        IoU score (0.0 to 1.0)
    """
    # Convert to binary masks
    pred_array = np.array(pred_mask.convert('L')) > 127
    gt_array = np.array(gt_mask.convert('L')) > 127

    # Calculate IoU
    intersection = np.logical_and(pred_array, gt_array).sum()
    union = np.logical_or(pred_array, gt_array).sum()

    if union == 0:
        return 0.0

    iou = intersection / union
    return float(iou)


def calculate_ssim(image1: Image.Image, image2: Image.Image) -> float:
    """
    Calculate the Structural Similarity Index (SSIM) between two images.

    Used to compare a processed image with a reference.

    Args:
        image1: First image
        image2: Second image

    Returns:
        SSIM score (-1.0 to 1.0, higher is better)
    """
    from skimage.metrics import structural_similarity as ssim

    # Convert to grayscale numpy arrays
    img1_gray = np.array(image1.convert('L'))
    img2_gray = np.array(image2.convert('L'))

    # Ensure the same size
    if img1_gray.shape != img2_gray.shape:
        # Resize to match
        img2_gray = cv2.resize(img2_gray, (img1_gray.shape[1], img1_gray.shape[0]))

    # Calculate SSIM
    score = ssim(img1_gray, img2_gray)
    return float(score)


def calculate_psnr(image1: Image.Image, image2: Image.Image) -> float:
    """
    Calculate the Peak Signal-to-Noise Ratio (PSNR) between two images.

    Higher PSNR = better quality (less noise/distortion).

    Args:
        image1: First image
        image2: Second image

    Returns:
        PSNR in dB (higher is better, typically 20-50)
    """
    # Convert to numpy arrays
    img1_array = np.array(image1.convert('RGB')).astype(float)
    img2_array = np.array(image2.convert('RGB')).astype(float)

    # Ensure the same size
    if img1_array.shape != img2_array.shape:
        img2_array = cv2.resize(img2_array, (img1_array.shape[1], img1_array.shape[0]))

    # Calculate MSE
    mse = np.mean((img1_array - img2_array) ** 2)

    if mse == 0:
        return 100.0  # Perfect match

    # Calculate PSNR
    max_pixel = 255.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))

    return float(psnr)


def calculate_transparency_coverage(image: Image.Image) -> float:
    """
    Calculate the fraction of an image that is transparent (for RGBA images).

    Useful for background removal evaluation.

    Args:
        image: RGBA PIL Image

    Returns:
        Transparency coverage (0.0 to 1.0)
    """
    if image.mode != 'RGBA':
        return 0.0

    img_array = np.array(image)
    alpha_channel = img_array[:, :, 3]

    # Count transparent pixels (alpha < 10)
    transparent_pixels = (alpha_channel < 10).sum()
    total_pixels = alpha_channel.size

    coverage = transparent_pixels / total_pixels
    return float(coverage)


def calculate_comprehensive_metrics(
    output_image: Image.Image,
    input_image: Optional[Image.Image] = None,
    ground_truth_mask: Optional[Image.Image] = None,
) -> Dict[str, float]:
    """
    Calculate all available quality metrics for an image.

    Args:
        output_image: Processed image (typically RGBA after bg removal)
        input_image: Original input image (for SSIM/PSNR comparison)
        ground_truth_mask: Ground truth mask (if available, for accuracy)

    Returns:
        Dictionary of metric names to scores
    """
    metrics = {}

    # Edge quality (always available)
    metrics['edge_quality'] = calculate_edge_quality(output_image)

    # Transparency coverage (for RGBA images)
    if output_image.mode == 'RGBA':
        metrics['transparency_coverage'] = calculate_transparency_coverage(output_image)

    # Comparison metrics (if an input image is provided)
    if input_image is not None:
        try:
            metrics['ssim'] = calculate_ssim(input_image, output_image)
        except Exception as e:
            print(f"SSIM calculation failed: {e}")
            metrics['ssim'] = 0.0

        try:
            metrics['psnr'] = calculate_psnr(input_image, output_image)
        except Exception as e:
            print(f"PSNR calculation failed: {e}")
            metrics['psnr'] = 0.0

    # Mask accuracy (if a ground truth mask is provided)
    if ground_truth_mask is not None and output_image.mode == 'RGBA':
        try:
            # Extract the alpha channel as the predicted mask
            alpha = output_image.split()[3]
            metrics['mask_accuracy'] = calculate_mask_accuracy(alpha, ground_truth_mask)
        except Exception as e:
            print(f"Mask accuracy calculation failed: {e}")
            metrics['mask_accuracy'] = 0.0

    return metrics


def calculate_weighted_quality_score(metrics: Dict[str, float]) -> float:
    """
    Calculate an overall weighted quality score from the individual metrics.

    Formula: mask_accuracy × 0.6 + edge_quality × 0.4.
    If mask_accuracy is unavailable, edge_quality alone is used.

    Args:
        metrics: Dictionary of metric scores

    Returns:
        Weighted quality score (0.0 to 1.0)
    """
    edge_quality = metrics.get('edge_quality', 0.0)
    mask_accuracy = metrics.get('mask_accuracy')

    if mask_accuracy is not None:
        # Use the weighted combination
        score = mask_accuracy * 0.6 + edge_quality * 0.4
    else:
        # Fall back to edge quality only
        score = edge_quality

    return float(score)


if __name__ == "__main__":
    # Demo
    print("=== Quality Metrics Module ===\n")
    print("Available metrics:")
    print("  - Edge Quality: Gradient-based edge sharpness")
    print("  - Mask Accuracy: IoU with ground truth (if available)")
    print("  - SSIM: Structural similarity (0-1)")
    print("  - PSNR: Peak signal-to-noise ratio (dB)")
    print("  - Transparency Coverage: % of transparent pixels")
    print("  - Weighted Quality Score: Combined metric for ranking")
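A minimal sketch of scoring a cutout against its source image, assuming the two hypothetical files below exist. No ground-truth mask is passed, so the weighted score falls back to edge quality alone.

from PIL import Image

from warp.inference.metrics import (
    calculate_comprehensive_metrics,
    calculate_weighted_quality_score,
)

original = Image.open("input.png")   # hypothetical source image
cutout = Image.open("output.png")    # hypothetical RGBA cutout

# SSIM/PSNR compare the cutout against the source; edge quality and
# transparency coverage are computed from the cutout's alpha channel.
metrics = calculate_comprehensive_metrics(output_image=cutout, input_image=original)
score = calculate_weighted_quality_score(metrics)

for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
print(f"weighted score: {score:.3f}")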
warp/inference/upscaler.py
ADDED
@@ -0,0 +1,117 @@
"""
Upscaler Interface - Unified interface for image upscaling models.

Placeholder for Phase 3 upscaling integration.
"""

from dataclasses import dataclass
from typing import Optional

from PIL import Image

from warp.models import get_model


@dataclass
class UpscaleResult:
    """Result of an upscaling operation."""

    output_image: Image.Image
    model_name: str
    processing_time_ms: int
    success: bool = True
    error_message: Optional[str] = None
    input_size: tuple = (0, 0)
    output_size: tuple = (0, 0)
    scale_factor: int = 4


class UpscalerEngine:
    """Engine for running upscaling models."""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the upscaler engine.

        Args:
            api_key: API key for external services (if needed)
        """
        self.api_key = api_key
        # TODO: Initialize clients when upscaling models are added

    def upscale(
        self,
        image: Image.Image,
        model_name: Optional[str] = None,
        scale_factor: int = 4,
    ) -> UpscaleResult:
        """
        Upscale an image using the specified model.

        Args:
            image: Input PIL Image
            model_name: Model name from registry (uses default if None)
            scale_factor: Upscaling factor (2x, 4x, etc.)

        Returns:
            UpscaleResult with output image and metrics

        Raises:
            ValueError: If the model is not found
            NotImplementedError: Until upscaling is implemented in Phase 3
        """
        # Get the model config
        if model_name is None:
            from warp.models import get_default_model

            model_config = get_default_model("upscale")
            if not model_config:
                raise ValueError("No default upscaling model found")
        else:
            model_config = get_model(model_name)
            if not model_config:
                raise ValueError(f"Model '{model_name}' not found in registry")

        # TODO: Implement actual upscaling in Phase 3
        raise NotImplementedError(
            f"Upscaling with {model_config.name} not yet implemented. "
            "This will be added in Phase 3."
        )


# Convenience function
def upscale_image(
    image: Image.Image,
    model_name: Optional[str] = None,
    scale_factor: int = 4,
    api_key: Optional[str] = None,
) -> Image.Image:
    """
    Quick upscaling function.

    Args:
        image: Input PIL Image
        model_name: Model name (uses default if None)
        scale_factor: Upscaling factor
        api_key: Optional API key

    Returns:
        Upscaled PIL Image
    """
    engine = UpscalerEngine(api_key=api_key)
    result = engine.upscale(image, model_name, scale_factor)

    if not result.success:
        raise RuntimeError(f"Upscaling failed: {result.error_message}")

    return result.output_image


if __name__ == "__main__":
    print("=== Upscaler Module ===\n")
    print("Status: Placeholder for Phase 3")
    print("\nPlanned upscaling models:")
    print("  - Real-ESRGAN 4x")
    print("  - GFPGAN (face restoration)")
    print("  - CodeFormer (face enhancement)")
    print("  - Swin2SR")
    print("\nUpscaling will be integrated in Phase 3 (Week 4-6)")
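The intended call shape, sketched as an assumption of how Phase 3 callers will look; today upscale() raises NotImplementedError for every registered model, and before.png is a hypothetical input.

from PIL import Image

from warp.inference.upscaler import UpscalerEngine

engine = UpscalerEngine()
image = Image.open("before.png")  # hypothetical input

try:
    result = engine.upscale(image, model_name="realesrgan", scale_factor=4)
    result.output_image.save("before_4x.png")
except NotImplementedError as e:
    # Expected until the Phase 3 implementation lands.
    print(e)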
warp/models/__init__.py
ADDED
@@ -0,0 +1,25 @@
"""WARP Models - Model registry and configuration."""

from .registry import (
    BACKGROUND_REMOVAL_MODELS,
    MODEL_REGISTRY,
    UPSCALING_MODELS,
    ModelConfig,
    get_default_model,
    get_model,
    get_models_by_type,
    list_all_models,
    list_model_names,
)

__all__ = [
    "ModelConfig",
    "MODEL_REGISTRY",
    "BACKGROUND_REMOVAL_MODELS",
    "UPSCALING_MODELS",
    "get_model",
    "get_models_by_type",
    "get_default_model",
    "list_all_models",
    "list_model_names",
]
warp/models/registry.py
ADDED
@@ -0,0 +1,248 @@
"""
Model Registry - Centralized configuration for all AI models in WARP.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ModelConfig:
    """Configuration for an AI model."""

    name: str
    display_name: str
    model_id: str  # Hugging Face model ID or path
    provider: str = "huggingface"  # 'huggingface', 'replicate', 'local'
    operation_type: str = "bg_removal"  # 'bg_removal', 'upscale', 'color_correct'
    description: str = ""
    version: Optional[str] = None

    # Model capabilities
    input_formats: List[str] = field(default_factory=lambda: ["png", "jpg", "jpeg", "webp"])
    output_format: str = "png"
    max_input_size: int = 2048  # Max dimension in pixels
    requires_gpu: bool = False

    # Default parameters
    default_parameters: Dict = field(default_factory=dict)

    # Performance hints
    estimated_time_ms: int = 3000
    is_default: bool = False
    is_active: bool = True


# ============================================================================
# Background Removal Models
# ============================================================================

BACKGROUND_REMOVAL_MODELS = {
    # Local models (using rembg)
    "isnet-general": ModelConfig(
        name="isnet-general",
        display_name="ISNet General (Local)",
        model_id="isnet-general-use",
        provider="local",
        operation_type="bg_removal",
        description="Improved segmentation model - highest quality, slower",
        estimated_time_ms=678,  # Updated from benchmark results
        is_default=False,
        default_parameters={},
    ),
    "u2net": ModelConfig(
        name="u2net",
        display_name="U2Net (Local)",
        model_id="u2net",
        provider="local",
        operation_type="bg_removal",
        description="General purpose background removal - BEST SPEED/QUALITY BALANCE",
        estimated_time_ms=416,  # Updated from benchmark results
        is_default=True,  # Updated based on Phase 2 benchmark results
        default_parameters={},
    ),
    "u2net-human": ModelConfig(
        name="u2net-human",
        display_name="U2Net Human (Local)",
        model_id="u2net_human_seg",
        provider="local",
        operation_type="bg_removal",
        description="Optimized for human segmentation and portraits",
        estimated_time_ms=436,  # Updated from benchmark results
        is_default=False,  # Changed: u2net outperformed it in benchmarks
        default_parameters={},
    ),
    "isnet-anime": ModelConfig(
        name="isnet-anime",
        display_name="ISNet Anime (Local)",
        model_id="isnet-anime",
        provider="local",
        operation_type="bg_removal",
        description="Specialized for anime/illustrations - not suitable for medical photos",
        estimated_time_ms=727,  # Updated from benchmark results
        is_default=False,
        default_parameters={},
    ),
    # Hugging Face models (currently unavailable)
    "rmbg-1.4": ModelConfig(
        name="rmbg-1.4",
        display_name="RMBG 1.4 (HF)",
        model_id="briaai/RMBG-1.4",
        operation_type="bg_removal",
        description="Fast and accurate background removal by Bria AI (v1.4) - currently unavailable",
        version="1.4",
        estimated_time_ms=2000,
        is_default=False,
        is_active=False,  # Disabled due to HF API deprecation
        default_parameters={},
    ),
}

# ============================================================================
# Upscaling Models
# ============================================================================

UPSCALING_MODELS = {
    "realesrgan": ModelConfig(
        name="realesrgan",
        display_name="Real-ESRGAN 4x",
        model_id="ai-forever/Real-ESRGAN",
        operation_type="upscale",
        description="Real-ESRGAN 4x upscaling for general images",
        estimated_time_ms=8000,
        is_default=True,
        requires_gpu=True,
        default_parameters={"scale": 4},
    ),
    "gfpgan": ModelConfig(
        name="gfpgan",
        display_name="GFPGAN",
        model_id="TencentARC/GFPGAN",
        operation_type="upscale",
        description="Face restoration and enhancement",
        estimated_time_ms=7000,
        is_default=False,
        requires_gpu=True,
        default_parameters={"version": "1.3"},
    ),
    "codeformer": ModelConfig(
        name="codeformer",
        display_name="CodeFormer",
        model_id="sczhou/CodeFormer",
        operation_type="upscale",
        description="Face restoration with better quality preservation",
        estimated_time_ms=9000,
        is_default=False,
        requires_gpu=True,
        default_parameters={"w": 0.5},
    ),
    "swin2sr": ModelConfig(
        name="swin2sr",
        display_name="Swin2SR",
        model_id="caidas/swin2SR-classical-sr-x4-64",
        operation_type="upscale",
        description="Swin Transformer for image super-resolution",
        estimated_time_ms=6000,
        is_default=False,
        requires_gpu=True,
        default_parameters={"scale": 4},
    ),
}

# ============================================================================
# Combined Registry
# ============================================================================

MODEL_REGISTRY = {
    **BACKGROUND_REMOVAL_MODELS,
    **UPSCALING_MODELS,
}


# ============================================================================
# Helper Functions
# ============================================================================


def get_model(name: str) -> Optional[ModelConfig]:
    """Get a model configuration by name."""
    return MODEL_REGISTRY.get(name)


def get_models_by_type(operation_type: str) -> Dict[str, ModelConfig]:
    """Get all models for a specific operation type."""
    return {
        name: config
        for name, config in MODEL_REGISTRY.items()
        if config.operation_type == operation_type and config.is_active
    }


def get_default_model(operation_type: str) -> Optional[ModelConfig]:
    """Get the default model for an operation type."""
    models = get_models_by_type(operation_type)
    for model in models.values():
        if model.is_default:
            return model
    # If no default is set, return the first active model
    return next(iter(models.values())) if models else None


def list_all_models() -> Dict[str, ModelConfig]:
    """Get all active models."""
    return {name: config for name, config in MODEL_REGISTRY.items() if config.is_active}


def list_model_names(operation_type: Optional[str] = None) -> List[str]:
    """Get a list of model names, optionally filtered by operation type."""
    if operation_type:
        return list(get_models_by_type(operation_type).keys())
    return list(MODEL_REGISTRY.keys())


# ============================================================================
# Model Information Display
# ============================================================================


def get_model_info(name: str) -> str:
    """Get formatted information about a model."""
    model = get_model(name)
    if not model:
        return f"Model '{name}' not found in registry."

    info = f"""
Model: {model.display_name}
ID: {model.model_id}
Type: {model.operation_type}
Provider: {model.provider}
Description: {model.description}
Est. Time: {model.estimated_time_ms}ms
Default: {'Yes' if model.is_default else 'No'}
Active: {'Yes' if model.is_active else 'No'}
""".strip()

    return info


if __name__ == "__main__":
    # Demo/test code
    print("=== WARP Model Registry ===\n")

    print("Background Removal Models:")
    for name in list_model_names("bg_removal"):
        print(f"  - {name}")

    print("\nUpscaling Models:")
    for name in list_model_names("upscale"):
        print(f"  - {name}")

    print("\nDefault Background Removal Model:")
    default_bg = get_default_model("bg_removal")
    if default_bg:
        print(f"  {default_bg.display_name} ({default_bg.name})")

    print("\nDefault Upscaling Model:")
    default_up = get_default_model("upscale")
    if default_up:
        print(f"  {default_up.display_name} ({default_up.name})")