nicka360 committed on
Commit 008c49a
1 Parent(s): b2e32b4

Restore lightweight warp package for HF Space runtime

.gitignore ADDED
@@ -0,0 +1,14 @@
+
+# keep warp lightweight
+warp/inference_outputs/
+warp/data/
+warp/**/checkpoints/
+warp/**/*.pt
+warp/**/*.pth
+warp/**/*.ckpt
+warp/**/*.safetensors
+warp/**/*.bin
+warp/**/*.npz
+*.zip
+*.7z
+*.tar
warp/__init__.py ADDED
File without changes
warp/gradio_app/__init__.py ADDED
@@ -0,0 +1 @@
+"""Gradio application for A360 WARP experimentation UI."""
warp/gradio_app/app-Nick.py ADDED
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import gradio as gr
+from dotenv import load_dotenv
+
+if TYPE_CHECKING:
+    from supabase import Client as ClientType
+
+    from warp.data import ImageLoader as ImageLoaderType
+    from warp.gradio_app.models.upscaler import ImageUpscaler
+else:
+    ClientType = object
+    ImageLoaderType = object
+    ImageUpscaler = object
+
+try:
+    from supabase import Client, create_client
+except Exception:
+    create_client = None  # type: ignore
+    Client = None  # type: ignore
+
+try:
+    from warp.data import ImageLoader
+except ImportError:
+    ImageLoader = None  # type: ignore
+
+try:
+    from warp.gradio_app.models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+except ImportError:
+    create_upscaler = None  # type: ignore
+    UPSCALER_AVAILABLE = False
+
+load_dotenv()
+SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
+SUPABASE_ANON_KEY: str = os.getenv("SUPABASE_ANON_KEY", "")
+
+supabase: ClientType | None = None
+if callable(create_client) and SUPABASE_URL and SUPABASE_ANON_KEY:
+    supabase = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
+
+# Initialize image loader
+image_loader: ImageLoaderType | None = None
+try:
+    if callable(ImageLoader):
+        image_loader = ImageLoader()
+        print(f"✓ Loaded {len(image_loader.practices)} practices with scraped images")
+except Exception as e:
+    print(f"Warning: Could not initialize ImageLoader: {e}")
+
+# Initialize upscaler (lazy load)
+upscaler: ImageUpscaler | None = None
+
+
+def load_practice_images(practice_name: str) -> tuple[list, str]:
+    """Load sample images from a practice.
+
+    Args:
+        practice_name: Name of the practice to load images from
+
+    Returns:
+        Tuple of (list of image paths, status message)
+    """
+    if not image_loader:
+        return [], "Image loader not available"
+
+    if not practice_name:
+        return [], "Please select a practice"
+
+    try:
+        # Get random sample of images
+        image_paths = image_loader.get_random_images(practice_name, n=10)
+        stats = image_loader.get_practice_stats(practice_name)
+        msg = (
+            f"Loaded {len(image_paths)} sample images from {practice_name} "
+            f"(Total: {stats['total_images']} images)"
+        )
+        return [str(p) for p in image_paths], msg
+    except Exception as e:
+        return [], f"Error loading images: {e}"
+
+
+def run_model(procedure: str | None, notes: str | None) -> str:
+    """Run a placeholder model execution.
+
+    Args:
+        procedure: The selected procedure type
+        notes: Additional context or parameters
+
+    Returns:
+        A formatted string with procedure and notes information
+    """
+    return f"Procedure={procedure or 'n/a'} | Notes={notes or 'n/a'}"
+
+
+def upscale_images(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float
+) -> tuple[object, object, str]:
+    """Upscale before/after image pair.
+
+    Args:
+        before_img: Before image from Gradio
+        after_img: After image from Gradio
+        prompt: Quality prompt for upscaling
+        num_steps: Number of inference steps
+        guidance: Guidance scale
+
+    Returns:
+        Tuple of (upscaled_before, upscaled_after, status_message)
+    """
+    global upscaler
+
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+
+    if before_img is None or after_img is None:
+        return None, None, "Please upload both before and after images"
+
+    try:
+        # Lazy load upscaler
+        if upscaler is None and callable(create_upscaler):
+            upscaler = create_upscaler(model_type="sd-x4")
+
+        # Import PIL here to handle the images
+        from PIL import Image
+
+        # Convert Gradio images to PIL if needed
+        if not isinstance(before_img, Image.Image):
+            before_img = Image.fromarray(before_img)
+        if not isinstance(after_img, Image.Image):
+            after_img = Image.fromarray(after_img)
+
+        # Upscale the pair
+        before_upscaled, after_upscaled = upscaler.upscale_pair(
+            before_img,
+            after_img,
+            prompt=prompt,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance,
+        )
+
+        return (
+            before_upscaled,
+            after_upscaled,
+            f"✓ Successfully upscaled images 4x (Original: {before_img.size} → Upscaled: {before_upscaled.size})",
+        )
+
+    except Exception as e:
+        return None, None, f"Error during upscaling: {str(e)}"
+
+
+# Build the UI
+with gr.Blocks(title="A360 WARP — Gradio") as demo:
+    gr.Markdown("# A360 WARP — Experimentation UI (MVP)")
+    gr.Markdown("Load and experiment with before/after images from scraped medical practices.")
+
+    # Practice selection and image loading
+    with gr.Tab("Image Browser"):
+        with gr.Row():
+            practice_dropdown = gr.Dropdown(
+                label="Select Practice",
+                choices=image_loader.practices if image_loader else [],
+                value=None,
+            )
+            load_btn = gr.Button("Load Sample Images", variant="primary")
+
+        status_text = gr.Textbox(label="Status", interactive=False)
+        image_gallery = gr.Gallery(label="Sample Images", show_label=True, columns=5, height="auto")
+
+        load_btn.click(
+            fn=load_practice_images,
+            inputs=[practice_dropdown],
+            outputs=[image_gallery, status_text],
+        )
+
+    # Image Enhancement (Upscaling)
+    with gr.Tab("Image Enhancement"):
+        gr.Markdown(
+            "### Upscale Before/After Images\n"
+            "Upload medical before/after photos to upscale them 4x using AI. "
+            "This improves image quality and detail for better comparison."
+        )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("#### Original Images")
+                before_input = gr.Image(label="Before Image", type="numpy")
+                after_input = gr.Image(label="After Image", type="numpy")
+
+            with gr.Column():
+                gr.Markdown("#### Upscaled Images (4x)")
+                before_output = gr.Image(label="Upscaled Before")
+                after_output = gr.Image(label="Upscaled After")
+
+        with gr.Row():
+            with gr.Column():
+                prompt_input = gr.Textbox(
+                    label="Quality Prompt",
+                    value="high quality medical photography, sharp details, professional lighting",
+                    placeholder="Describe desired image quality...",
+                )
+            with gr.Column():
+                num_steps = gr.Slider(
+                    minimum=20,
+                    maximum=100,
+                    value=50,
+                    step=5,
+                    label="Inference Steps (higher = better quality, slower)",
+                )
+                guidance_scale = gr.Slider(
+                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
+                )
+
+        upscale_btn = gr.Button("Upscale Images", variant="primary", size="lg")
+        upscale_status = gr.Textbox(label="Status", interactive=False)
+
+        upscale_btn.click(
+            fn=upscale_images,
+            inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
+            outputs=[before_output, after_output, upscale_status],
+        )
+
+    # Model experimentation
+    with gr.Tab("Model Experiments"):
+        with gr.Row():
+            procedure = gr.Dropdown(
+                label="Procedure",
+                choices=[
+                    "breast-augmentation",
+                    "liposuction",
+                    "rhinoplasty",
+                    "ftm-top-surgery",
+                    "coolsculpting",
+                ],
+                value=None,
+            )
+            notes = gr.Textbox(label="Notes", placeholder="Run context / params…")
+        run = gr.Button("Run")
+        out = gr.Textbox(label="Output")
+
+        run.click(run_model, inputs=[procedure, notes], outputs=out)
+
+if __name__ == "__main__":
+    demo.launch()
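
The guarded imports at the top of this file (and of `app.py` below) all follow the same pattern: import under `TYPE_CHECKING` so annotations stay checkable, then re-import at runtime behind `try`/`except` and downgrade to a `None` sentinel so the app still starts when the dependency is absent. A minimal, self-contained sketch of that pattern, assuming a hypothetical optional dependency `somelib`:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from somelib import Client as ClientType  # hypothetical optional dependency
else:
    ClientType = object  # runtime placeholder so annotations still resolve

try:
    from somelib import create_client
except ImportError:
    create_client = None  # feature degrades gracefully when somelib is missing

client: ClientType | None = None
if callable(create_client):  # only True when the real import succeeded
    client = create_client()
```

The `callable(...)` check doubles as the availability test, which is why the app files never need a separate `SOMELIB_AVAILABLE` flag for supabase.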
warp/gradio_app/app.py ADDED
@@ -0,0 +1,359 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import gradio as gr
+from dotenv import load_dotenv
+
+if TYPE_CHECKING:
+    from supabase import Client as ClientType
+
+    from warp.data import ImageLoader as ImageLoaderType
+    from warp.gradio_app.models.upscaler import ImageUpscaler
+else:
+    ClientType = object
+    ImageLoaderType = object
+    ImageUpscaler = object
+
+try:
+    from supabase import Client, create_client
+except Exception:
+    create_client = None  # type: ignore
+    Client = None  # type: ignore
+
+try:
+    from warp.data import ImageLoader
+except ImportError:
+    ImageLoader = None  # type: ignore
+
+try:
+    from warp.gradio_app.models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+    print("✓ Upscaler module loaded successfully")
+except ImportError as e:
+    create_upscaler = None  # type: ignore
+    UPSCALER_AVAILABLE = False
+    print(f"✗ Upscaler import failed: {e}")
+
+# Temporarily disable Advanced Upscaling tab due to import issues
+# Will re-enable after fixing module resolution
+COMPARE_TAB_AVAILABLE = False
+build_upscale_compare = None
+
+load_dotenv()
+SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
+SUPABASE_ANON_KEY: str = os.getenv("SUPABASE_ANON_KEY", "")
+
+supabase: ClientType | None = None
+if callable(create_client) and SUPABASE_URL and SUPABASE_ANON_KEY:
+    supabase = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)
+
+# Initialize image loader
+image_loader: ImageLoaderType | None = None
+try:
+    if callable(ImageLoader):
+        image_loader = ImageLoader()
+        print(f"✓ Loaded {len(image_loader.practices)} practices with scraped images")
+except Exception as e:
+    # If initialization fails for any reason (including mocked import errors),
+    # fall back to no image loader so the rest of the app can still import.
+    image_loader = None
+    print(f"Warning: Could not initialize ImageLoader: {e}")
+
+# Initialize upscaler (lazy load)
+upscaler: ImageUpscaler | None = None
+
+
+def load_practice_images(practice_name: str) -> tuple[list, str]:
+    """Load sample images from a practice.
+
+    Args:
+        practice_name: Name of the practice to load images from
+
+    Returns:
+        Tuple of (list of image paths, status message)
+    """
+    if not image_loader:
+        return [], "Image loader not available"
+
+    if not practice_name:
+        return [], "Please select a practice"
+
+    try:
+        # Get random sample of images
+        image_paths = image_loader.get_random_images(practice_name, n=10)
+        stats = image_loader.get_practice_stats(practice_name)
+        msg = (
+            f"Loaded {len(image_paths)} sample images from {practice_name} "
+            f"(Total: {stats['total_images']} images)"
+        )
+        return [str(p) for p in image_paths], msg
+    except Exception as e:
+        return [], f"Error loading images: {e}"
+
+
+def run_model(procedure: str | None, notes: str | None) -> str:
+    """Run a placeholder model execution.
+
+    Args:
+        procedure: The selected procedure type
+        notes: Additional context or parameters
+
+    Returns:
+        A formatted string with procedure and notes information
+    """
+    return f"Procedure={procedure or 'n/a'} | Notes={notes or 'n/a'}"
+
+
+def upscale_images(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
+):
+    """Upscale before/after image pair (synchronous helper).
+
+    This version is a simple function (not a generator) so tests can call it
+    and make assertions about the returned tuple. The Gradio UI wraps this in
+    a streaming function that updates the progress bar and status text.
+
+    Returns:
+        Tuple of (upscaled_before, upscaled_after, status_message)
+    """
+    global upscaler
+
+    # Handle missing upscaler dependency
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+
+    # Validate inputs
+    if before_img is None or after_img is None:
+        return None, None, "Please upload both before and after images"
+
+    try:
+        # Lazy load upscaler on first use
+        if upscaler is None and callable(create_upscaler):
+            upscaler = create_upscaler(model_type="sd-x4")
+
+        # Import PIL here to handle the images
+        from PIL import Image
+
+        # Convert numpy arrays to PIL Images if needed
+        if not isinstance(before_img, Image.Image):
+            before_img = Image.fromarray(before_img)
+        if not isinstance(after_img, Image.Image):
+            after_img = Image.fromarray(after_img)
+
+        orig_size = before_img.size
+
+        # Use the pair upscaling helper with a callback that updates the
+        # Gradio progress bar more granularly during diffusion steps.
+        callback_state = {"phase": "before", "last_step": -1}
+
+        def progress_callback(step, timestep, latents):  # type: ignore[unused-argument]
+            """Update progress bar for each diffusion step.
+
+            We see steps 0..num_steps-1 for the "before" image first, then
+            again for the "after" image. When the step counter resets, we
+            switch to the "after" phase and map progress into [0.5, 0.9].
+            """
+            try:
+                # Detect phase change when step counter resets
+                if step < callback_state["last_step"]:
+                    callback_state["phase"] = "after"
+                callback_state["last_step"] = step
+
+                frac = step / max(num_steps, 1)
+                if callback_state["phase"] == "before":
+                    # Map to [0.1, 0.5]
+                    pct = 0.1 + 0.4 * frac
+                    desc = f"Upscaling BEFORE image ({step}/{num_steps})"
+                else:
+                    # Map to [0.5, 0.9]
+                    pct = 0.5 + 0.4 * frac
+                    desc = f"Upscaling AFTER image ({step}/{num_steps})"
+
+                try:
+                    progress(pct, desc=desc)
+                except Exception:
+                    # In tests or non-Gradio contexts, progress may be a no-op
+                    pass
+            except Exception:
+                # Never allow progress UI issues to break the core upscaling
+                pass
+
+        before_upscaled, after_upscaled = upscaler.upscale_pair(
+            before_img,
+            after_img,
+            prompt=prompt,
+            num_inference_steps=num_steps,
+            guidance_scale=guidance,
+            callback=progress_callback,
+            callback_steps=1,
+        )
+
+        status = (
+            "Successfully upscaled both images 4x\n"
+            f"Original: {orig_size[0]}×{orig_size[1]} → "
+            f"Upscaled: {before_upscaled.size[0]}×{before_upscaled.size[1]}"
+        )
+        return before_upscaled, after_upscaled, status
+
+    except Exception as e:
+        # Graceful error handling for tests and UI
+        return None, None, f"Error during upscaling: {str(e)}"
+
+
+def upscale_images_stream(
+    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
+):
+    """Streaming wrapper for ``upscale_images`` used by the Gradio UI.
+
+    Yields intermediate status updates so the user sees a live progress bar
+    and status text while the heavy model runs.
+    """
+    # Handle missing upscaler dependency
+    if not UPSCALER_AVAILABLE:
+        yield (
+            None,
+            None,
+            "Upscaler not available. Install: pip install torch diffusers transformers",
+        )
+        return
+
+    # Validate inputs
+    if before_img is None or after_img is None:
+        yield None, None, "Please upload both before and after images"
+        return
+
+    try:
+        # Initial progress
+        try:
+            progress(0.0, desc="Initializing upscaler...")
+        except Exception:
+            pass
+        yield None, None, "Initializing upscaler..."
+
+        # Coarse progress while running the model
+        try:
+            progress(0.3, desc="Upscaling images...")
+        except Exception:
+            pass
+
+        before_upscaled, after_upscaled, status = upscale_images(
+            before_img, after_img, prompt, num_steps, guidance, progress
+        )
+
+        try:
+            progress(1.0, desc="Complete")
+        except Exception:
+            pass
+
+        yield before_upscaled, after_upscaled, status
+
+    except Exception as e:
+        yield None, None, f"Error during upscaling: {str(e)}"
+
+
+# Build the UI
+with gr.Blocks(title="A360 WARP — Gradio") as demo:
+    gr.Markdown("# A360 WARP — Experimentation UI (MVP)")
+    gr.Markdown("Load and experiment with before/after images from scraped medical practices.")
+
+    # Practice selection and image loading
+    with gr.Tab("Image Browser"):
+        with gr.Row():
+            practice_dropdown = gr.Dropdown(
+                label="Select Practice",
+                choices=image_loader.practices if image_loader else [],
+                value=None,
+            )
+            load_btn = gr.Button("Load Sample Images", variant="primary")
+
+        status_text = gr.Textbox(label="Status", interactive=False)
+        image_gallery = gr.Gallery(label="Sample Images", show_label=True, columns=5, height="auto")
+
+        load_btn.click(
+            fn=load_practice_images,
+            inputs=[practice_dropdown],
+            outputs=[image_gallery, status_text],
+        )
+
+    # Image Enhancement (Upscaling)
+    with gr.Tab("Image Enhancement"):
+        gr.Markdown(
+            "### Upscale Before/After Images\n"
+            "Upload medical before/after photos to upscale them 4x using AI. "
+            "This improves image quality and detail for better comparison."
+        )
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("#### Original Images")
+                before_input = gr.Image(label="Before Image", type="numpy")
+                after_input = gr.Image(label="After Image", type="numpy")
+
+            with gr.Column():
+                gr.Markdown("#### Upscaled Images (4x)")
+                before_output = gr.Image(label="Upscaled Before")
+                after_output = gr.Image(label="Upscaled After")
+
+        with gr.Row():
+            with gr.Column():
+                prompt_input = gr.Textbox(
+                    label="Quality Prompt",
+                    value="high quality medical photography, sharp details, professional lighting",
+                    placeholder="Describe desired image quality...",
+                )
+            with gr.Column():
+                num_steps = gr.Slider(
+                    minimum=20,
+                    maximum=100,
+                    value=50,
+                    step=5,
+                    label="Inference Steps (higher = better quality, slower)",
+                )
+                guidance_scale = gr.Slider(
+                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
+                )
+
+        upscale_btn = gr.Button("Upscale Images", variant="primary", size="lg")
+        upscale_status = gr.Textbox(label="Status", interactive=False)
+
+        # Use the streaming wrapper so users see live progress/status updates
+        upscale_btn.click(
+            fn=upscale_images_stream,
+            inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
+            outputs=[before_output, after_output, upscale_status],
+        )
+
+    # Advanced Upscaling with Comparison
+    if COMPARE_TAB_AVAILABLE and build_upscale_compare:
+        with gr.Tab("Advanced Upscaling"):
+            build_upscale_compare()
+
+    # Model experimentation
+    with gr.Tab("Model Experiments"):
+        with gr.Row():
+            procedure = gr.Dropdown(
+                label="Procedure",
+                choices=[
+                    "breast-augmentation",
+                    "liposuction",
+                    "rhinoplasty",
+                    "ftm-top-surgery",
+                    "coolsculpting",
+                ],
+                value=None,
+            )
+            notes = gr.Textbox(label="Notes", placeholder="Run context / params…")
+        run = gr.Button("Run")
+        out = gr.Textbox(label="Output")
+
+        run.click(run_model, inputs=[procedure, notes], outputs=out)
+
+if __name__ == "__main__":
+    demo.launch()
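
The phase-mapped progress arithmetic inside `upscale_images` is easy to sanity-check in isolation: the "before" pass fills [0.1, 0.5] of the bar and the "after" pass fills [0.5, 0.9]. A minimal sketch of the same mapping as a pure function (no Gradio needed), with a few checked values:

```python
def phase_progress(step: int, num_steps: int, phase: str) -> float:
    """Map a diffusion step to an overall progress fraction.

    The "before" image occupies [0.1, 0.5]; the "after" image [0.5, 0.9].
    """
    frac = step / max(num_steps, 1)  # guard against num_steps == 0
    base = 0.1 if phase == "before" else 0.5
    return base + 0.4 * frac


# Spot-check the endpoints and midpoint (tolerance for float rounding).
cases = [
    (0, "before", 0.1),   # first image starts
    (25, "before", 0.3),  # halfway through first image
    (0, "after", 0.5),    # second image begins where the first ended
    (50, "after", 0.9),   # both images done; stream then jumps to 1.0
]
for step, phase, expected in cases:
    assert abs(phase_progress(step, 50, phase) - expected) < 1e-9
```

The remaining [0.9, 1.0] slice is covered by the explicit `progress(1.0, desc="Complete")` call in `upscale_images_stream`.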
warp/gradio_app/model_comparison.py ADDED
@@ -0,0 +1,288 @@
+"""
+Model Comparison App - Gradio interface for testing and comparing background removal models.
+"""
+
+import os
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+import pandas as pd
+from dotenv import load_dotenv
+from PIL import Image
+
+# Load environment variables from .env file
+env_path = Path(__file__).parent.parent.parent / ".env"
+load_dotenv(env_path)
+
+from warp.data import ImageLoader
+from warp.inference.background_removal import BackgroundRemovalEngine, BackgroundRemovalResult
+from warp.models import list_model_names
+
+# Initialize components
+engine = BackgroundRemovalEngine()
+loader = ImageLoader()
+
+
+def load_test_image(practice: str) -> Optional[Image.Image]:
+    """Load a random test image from a practice."""
+    try:
+        images = loader.get_random_images(practice, n=1)
+        if images:
+            return loader.load_image(images[0])
+    except Exception as e:
+        print(f"Error loading image: {e}")
+    return None
+
+
+def run_comparison(
+    image: Image.Image, selected_models: List[str]
+) -> Tuple[List, pd.DataFrame, str]:
+    """
+    Run background removal comparison across selected models.
+
+    Args:
+        image: Input PIL Image
+        selected_models: List of model names to test
+
+    Returns:
+        Tuple of (gallery_images, metrics_df, summary_text)
+    """
+    if not image:
+        return [], pd.DataFrame(), "❌ No image provided"
+
+    if not selected_models:
+        return [], pd.DataFrame(), "❌ No models selected"
+
+    print(f"\n{'='*60}")
+    print(f"Running comparison with {len(selected_models)} models...")
+    print(f"{'='*60}")
+
+    # Run comparisons
+    results = {}
+    for model_name in selected_models:
+        result = engine.remove_background(image, model_name)
+        results[model_name] = result
+
+    # Prepare outputs for gallery (list of tuples with image and caption)
+    output_images = [
+        (res.output_image, f"{name}\n{res.processing_time_ms}ms")
+        for name, res in results.items()
+        if res.success
+    ]
+
+    # Create metrics DataFrame
+    metrics_data = []
+    for name, result in results.items():
+        metrics_data.append(
+            {
+                "Model": name,
+                "Status": "✓ Success" if result.success else "✗ Failed",
+                "Time (ms)": result.processing_time_ms,
+                "Edge Quality": f"{result.edge_quality:.3f}" if result.edge_quality else "-",
+                "SSIM": f"{result.ssim:.3f}" if result.ssim else "-",
+                "PSNR (dB)": f"{result.psnr:.1f}" if result.psnr else "-",
+                "Transparency": f"{result.transparency_coverage:.1%}" if result.transparency_coverage else "-",
+                "Quality Score": f"{result.weighted_quality_score:.3f}" if result.weighted_quality_score else "-",
+                "Error": result.error_message or "-",
+            }
+        )
+
+    metrics_df = pd.DataFrame(metrics_data)
+
+    # Create summary
+    successful = sum(1 for r in results.values() if r.success)
+    avg_time = (
+        sum(r.processing_time_ms for r in results.values() if r.success) / successful
+        if successful > 0
+        else 0
+    )
+
+    summary = f"""
+## Comparison Summary
+
+- **Models Tested**: {len(selected_models)}
+- **Successful**: {successful}
+- **Failed**: {len(selected_models) - successful}
+- **Average Time**: {avg_time:.0f}ms
+""".strip()
+
+    return output_images, metrics_df, summary
+
+
+def run_single_model(
+    image: Image.Image, model_name: str
+) -> Tuple[Optional[Image.Image], str]:
+    """Run a single model and return output + info."""
+    if not image:
+        return None, "❌ No image provided"
+
+    result = engine.remove_background(image, model_name)
+
+    if result.success:
+        info = f"""
+### ✓ Success
+
+- **Model**: {result.model_name}
+- **Processing Time**: {result.processing_time_ms}ms
+- **Input Size**: {result.input_size[0]}x{result.input_size[1]}
+- **Output Size**: {result.output_size[0]}x{result.output_size[1]}
+""".strip()
+        return result.output_image, info
+    else:
+        info = f"""
+### ✗ Failed
+
+- **Model**: {result.model_name}
+- **Error**: {result.error_message}
+- **Time Elapsed**: {result.processing_time_ms}ms
+""".strip()
+        return image, info
+
+
+# ============================================================================
+# Gradio Interface
+# ============================================================================
+
+
+def create_comparison_tab():
+    """Create the model comparison tab."""
+    with gr.Column():
+        gr.Markdown("# 🔬 Model Comparison")
+        gr.Markdown("Compare multiple background removal models side-by-side")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(
+                    type="pil", label="Input Image", height=400
+                )
+
+                # Model selection
+                available_models = list_model_names("bg_removal")
+                model_selector = gr.CheckboxGroup(
+                    choices=available_models,
+                    value=[available_models[0]] if available_models else [],
+                    label="Select Models to Compare",
+                )
+
+                # Quick load options
+                with gr.Row():
+                    practice_dropdown = gr.Dropdown(
+                        choices=loader.practices if hasattr(loader, "practices") else [],
+                        label="Load Random Image From",
+                        value=None,
+                    )
+                    load_btn = gr.Button("📁 Load Sample", size="sm")
+
+                run_btn = gr.Button("▶️ Run Comparison", variant="primary", size="lg")
+
+            with gr.Column(scale=2):
+                summary_md = gr.Markdown("### Ready to compare models")
+                metrics_table = gr.DataFrame(label="Performance Metrics")
+
+        # Output gallery
+        gr.Markdown("### Model Outputs")
+        output_gallery = gr.Gallery(
+            label="Results", columns=3, height="auto", object_fit="contain"
+        )
+
+        # Wire up events
+        def load_sample(practice):
+            if practice:
+                img = load_test_image(practice)
+                return img
+            return None
+
+        load_btn.click(fn=load_sample, inputs=[practice_dropdown], outputs=[input_image])
+
+        run_btn.click(
+            fn=run_comparison,
+            inputs=[input_image, model_selector],
+            outputs=[output_gallery, metrics_table, summary_md],
+        )
+
+
+def create_single_model_tab():
+    """Create the single model testing tab."""
+    with gr.Column():
+        gr.Markdown("# 🎯 Single Model Test")
+        gr.Markdown("Test a single model with detailed results")
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                input_image = gr.Image(type="pil", label="Input Image")
+
+                model_dropdown = gr.Dropdown(
+                    choices=list_model_names("bg_removal"),
+                    value=list_model_names("bg_removal")[0],
+                    label="Select Model",
+                )
+
+                run_btn = gr.Button("▶️ Remove Background", variant="primary")
+
+            with gr.Column(scale=1):
+                output_image = gr.Image(type="pil", label="Output Image")
+                result_info = gr.Markdown("### Waiting for input...")
+
+        run_btn.click(
+            fn=run_single_model,
+            inputs=[input_image, model_dropdown],
+            outputs=[output_image, result_info],
+        )
+
+
+def create_app():
+    """Create the full Gradio app."""
+    with gr.Blocks(title="WARP Model Comparison", theme=gr.themes.Soft()) as app:
+        gr.Markdown(
+            """
+# 🚀 WARP Model Test Harness
+
+**AI-Powered Image Processing Pipeline**
+
+Test and compare background removal models with real-time performance metrics.
+"""
+        )
+
+        with gr.Tabs():
+            with gr.Tab("Model Comparison"):
+                create_comparison_tab()
+
+            with gr.Tab("Single Model Test"):
+                create_single_model_tab()
+
+            with gr.Tab("Model Registry"):
+                gr.Markdown("## 📚 Available Models")
+
+                # Display model registry
+                from warp.models import get_models_by_type
+
+                bg_models = get_models_by_type("bg_removal")
+                model_info_data = []
+
+                for name, config in bg_models.items():
+                    model_info_data.append(
+                        {
+                            "Name": config.name,
+                            "Display Name": config.display_name,
+                            "Model ID": config.model_id,
+                            "Est. Time": f"{config.estimated_time_ms}ms",
+                            "Default": "✓" if config.is_default else "",
+                            "Description": config.description,
+                        }
+                    )
+
+                gr.DataFrame(value=pd.DataFrame(model_info_data), label="Background Removal Models")
+
+    return app
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+if __name__ == "__main__":
+    app = create_app()
+    app.launch(
+        server_name="127.0.0.1",  # Bind to localhost so the browser URL is valid on Windows
+        server_port=7860,
+        share=False,
+        show_error=True,
+    )
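
`run_comparison` hands `gr.Gallery` a list of `(image, caption)` tuples and `gr.DataFrame` a pandas frame built from plain dicts. A small sketch of those two output shapes in isolation, using placeholder images and made-up timings (all values here are hypothetical stand-ins, not results from the engine):

```python
import pandas as pd
from PIL import Image

# Stand-ins for BackgroundRemovalResult objects.
fake_results = {
    "model-a": {"image": Image.new("RGBA", (64, 64)), "time_ms": 412, "success": True},
    "model-b": {"image": Image.new("RGBA", (64, 64)), "time_ms": 655, "success": True},
}

# Gallery items: (image, caption) tuples, the format run_comparison produces.
gallery = [
    (r["image"], f"{name}\n{r['time_ms']}ms")
    for name, r in fake_results.items()
    if r["success"]
]

# Metrics table: one dict per row, mirroring the DataFrame built above.
df = pd.DataFrame(
    [
        {"Model": name, "Status": "✓ Success" if r["success"] else "✗ Failed", "Time (ms)": r["time_ms"]}
        for name, r in fake_results.items()
    ]
)
print(df)
```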
warp/gradio_app/models/__init__.py ADDED
@@ -0,0 +1,5 @@
+"""Model registry for A360 WARP."""
+
+from .registry import MODELS
+
+__all__ = ["MODELS"]
warp/gradio_app/models/registry-Nick.py ADDED
@@ -0,0 +1,9 @@
+MODELS = {
+    "CLIP": "openai/clip-vit-base-patch32",
+    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
+    "DINOv2": "facebook/dinov2-base",
+}
+
+UPSCALER_MODELS = {
+    "SD-X4-Upscaler": "stabilityai/stable-diffusion-x4-upscaler",
+}
warp/gradio_app/models/registry.py ADDED
@@ -0,0 +1,5 @@
+MODELS = {
+    "CLIP": "openai/clip-vit-base-patch32",
+    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
+    "DINOv2": "facebook/dinov2-base",
+}
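
The registry maps display names to Hugging Face model ids, so a caller resolves the name once and hands the id to whatever loader it uses. A minimal usage sketch, assuming `transformers` is installed (the CLIP id below is the one in the registry):

```python
from transformers import CLIPModel, CLIPProcessor

from warp.gradio_app.models.registry import MODELS

model_id = MODELS["CLIP"]  # "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)          # weights
processor = CLIPProcessor.from_pretrained(model_id)  # matching preprocessing
```

Keeping the id strings in one dict means a model swap is a one-line registry change rather than an edit in every call site.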
warp/gradio_app/models/upscaler.py ADDED
@@ -0,0 +1,211 @@
+"""Image upscaling using HuggingFace models."""
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Literal
+
+from PIL import Image
+
+if TYPE_CHECKING:
+    from diffusers import StableDiffusionUpscalePipeline as PipelineType
+else:
+    PipelineType = object
+
+try:
+    import torch
+    from diffusers import StableDiffusionUpscalePipeline
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    # If either torch or diffusers is missing, mark them unavailable. The
+    # tests patch these symbols as needed, and the runtime gracefully degrades
+    # by raising a clear ImportError from ImageUpscaler.__init__.
+    torch = None  # type: ignore[assignment]
+    StableDiffusionUpscalePipeline = None  # type: ignore[assignment]
+    TORCH_AVAILABLE = False
+
+
+class ImageUpscaler:
+    """Handle image upscaling using HuggingFace models."""
+
+    def __init__(
+        self, model_id: str = "stabilityai/stable-diffusion-x4-upscaler", device: str | None = None
+    ):
+        """Initialize the upscaler.
+
+        Args:
+            model_id: HuggingFace model identifier
+            device: Device to run model on ('cuda', 'cpu', or None for auto)
+        """
+        if not TORCH_AVAILABLE:
+            raise ImportError(
+                "torch and diffusers are required for upscaling. "
+                "Install with: pip install torch diffusers transformers"
+            )
+
+        self.model_id = model_id
+
+        # Auto-detect device
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.pipeline: PipelineType | None = None
+        self._load_model()
+
+    def _load_model(self) -> None:
+        """Load the upscaling model."""
+        print(f"Loading upscaler model: {self.model_id} on {self.device}...")
+
+        # Determine torch dtype based on device
+        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
+
+        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+            self.model_id, torch_dtype=torch_dtype
+        )
+        self.pipeline = self.pipeline.to(self.device)
+
+        # Enable memory optimizations if on CUDA
+        if self.device == "cuda":
+            self.pipeline.enable_attention_slicing()
+
+        print(f"✓ Model loaded successfully on {self.device}")
+
+    def upscale(
+        self,
+        image: Image.Image | str | Path,
+        prompt: str = "high quality, detailed, sharp",
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        callback=None,
+        callback_steps: int = 1,
+    ) -> Image.Image:
+        """Upscale an image 4x.
+
+        Args:
+            image: PIL Image or path to image file
+            prompt: Text prompt to guide upscaling (helps with quality)
+            num_inference_steps: Number of denoising steps (higher = better quality, slower)
+            guidance_scale: How closely to follow the prompt (7.5 is good default)
+            callback: Optional callback function(step, timestep, latents) called each step
+            callback_steps: How often to call the callback (default: every step)
+
+        Returns:
+            Upscaled PIL Image
+        """
+        # Load image if path is provided
+        if isinstance(image, (str, Path)):
+            image = Image.open(image).convert("RGB")
+
+        # Ensure RGB mode
+        if image.mode != "RGB":
+            image = image.convert("RGB")
+
+        # Run upscaling
+        if self.pipeline is None:
+            raise RuntimeError("Pipeline not initialized")
+        result = self.pipeline(
+            prompt=prompt,
+            image=image,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            callback=callback,
+            callback_steps=callback_steps,
+        )
+        upscaled: Image.Image = result.images[0]
+
+        return upscaled
+
+    def upscale_pair(
+        self,
+        before_image: Image.Image | str | Path,
+        after_image: Image.Image | str | Path,
+        prompt: str = "high quality medical photography, sharp details, professional lighting",
+        **kwargs,
+    ) -> tuple[Image.Image, Image.Image]:
+        """Upscale a before/after image pair.
+
+        Args:
+            before_image: Before image (PIL Image or path)
+            after_image: After image (PIL Image or path)
+            prompt: Text prompt for upscaling quality
+            **kwargs: Additional arguments for upscale()
+
+        Returns:
+            Tuple of (upscaled_before, upscaled_after)
+        """
+        print("Upscaling before image...")
+        before_upscaled = self.upscale(before_image, prompt=prompt, **kwargs)
+
+        print("Upscaling after image...")
+        after_upscaled = self.upscale(after_image, prompt=prompt, **kwargs)
+
+        return before_upscaled, after_upscaled
+
+    def batch_upscale(
+        self,
+        images: list[Image.Image | str | Path],
+        prompt: str = "high quality, detailed, sharp",
+        **kwargs,
+    ) -> list[Image.Image]:
+        """Upscale multiple images.
+
+        Args:
+            images: List of PIL Images or paths
+            prompt: Text prompt for upscaling
+            **kwargs: Additional arguments for upscale()
+
+        Returns:
+            List of upscaled PIL Images
+        """
+        results = []
+        for i, img in enumerate(images, 1):
+            print(f"Upscaling image {i}/{len(images)}...")
+            upscaled = self.upscale(img, prompt=prompt, **kwargs)
+            results.append(upscaled)
+        return results
+
+
+def create_upscaler(
+    model_type: Literal["sd-x4", "fast"] = "sd-x4", device: str | None = None
+) -> ImageUpscaler:
+    """Factory function to create an upscaler.
+
+    Args:
+        model_type: Type of upscaler model
+            - "sd-x4": Stable Diffusion 4x upscaler (high quality, slower)
+            - "fast": Faster alternative (to be implemented)
+        device: Device to run on ('cuda', 'cpu', or None for auto)
+
+    Returns:
+        Initialized ImageUpscaler
+    """
+    model_map = {
+        "sd-x4": "stabilityai/stable-diffusion-x4-upscaler",
+        # Can add more models here later
+    }
+
+    model_id = model_map.get(model_type, model_map["sd-x4"])
+    return ImageUpscaler(model_id=model_id, device=device)
+
+
+# NOTE:
+# -----
+# When this module is imported as a submodule of ``warp.gradio_app.models``
+# (e.g. via ``from warp.gradio_app.models import upscaler``), Python normally
+# caches it as an attribute on the parent package. That caching can interfere
+# with tests that manipulate ``sys.modules`` to simulate import failures
+# (like removing ``torch``/``diffusers`` and re-importing this module).
+#
+# To ensure those tests can reliably exercise the fallback path, we avoid
+# permanently caching this submodule on the parent package by removing the
+# attribute if it exists. The module itself remains available via
+# ``sys.modules['warp.gradio_app.models.upscaler']``.
+try:  # Best-effort; never fail import because of this cleanup.
+    import sys as _sys
+
+    _parent_pkg = _sys.modules.get("warp.gradio_app.models")
+    if _parent_pkg is not None and hasattr(_parent_pkg, "upscaler"):
+        delattr(_parent_pkg, "upscaler")
+except Exception:
+    pass
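
Typical use of this module goes through the factory rather than the class. A short sketch, assuming torch/diffusers are installed and that `before.jpg` and `after.jpg` exist in the working directory (both file names are placeholders):

```python
from warp.gradio_app.models.upscaler import create_upscaler

# Factory resolves "sd-x4" to the SD x4 upscaler id and auto-selects cuda/cpu.
upscaler = create_upscaler(model_type="sd-x4")

# upscale_pair accepts paths or PIL Images; kwargs flow through to upscale().
before_up, after_up = upscaler.upscale_pair(
    "before.jpg",
    "after.jpg",
    num_inference_steps=30,  # fewer steps for a quicker first run
)

before_up.save("before_4x.png")
after_up.save("after_4x.png")
```

Because `upscale_pair` forwards `**kwargs` to `upscale()`, a progress `callback` (as wired up in `app.py`) can be passed the same way.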
warp/gradio_app/upscale_compare_tab.py ADDED
@@ -0,0 +1,231 @@
+"""Advanced upscaling tab with before/after comparison and detailed metrics."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import gradio as gr
+import numpy as np
+from PIL import Image, ImageDraw
+
+if TYPE_CHECKING:
+    from .models.upscaler import ImageUpscaler
+
+try:
+    from .models.upscaler import create_upscaler
+
+    UPSCALER_AVAILABLE = True
+except (ImportError, ModuleNotFoundError):
+    create_upscaler = None  # type: ignore[assignment]
+    UPSCALER_AVAILABLE = False
+
+
+# Global upscaler instance (lazy load)
+_upscaler: ImageUpscaler | None = None
+
+# Configuration
+MAX_INPUT_WIDTH = 1024
+MAX_INPUT_HEIGHT = 1024
+UPSCALE_FACTOR = 4
+QUALITY_PROMPT = "ultra realistic, natural contrast, high clarity, clean skin texture, professional medical photography"
+
+
+def _get_upscaler() -> ImageUpscaler | None:
+    """Lazy load upscaler on first use."""
+    global _upscaler
+    if _upscaler is None and UPSCALER_AVAILABLE:
+        try:
+            _upscaler = create_upscaler(model_type="sd-x4")
+        except Exception as e:
+            raise RuntimeError(f"Failed to load upscaler: {e}") from e
+    return _upscaler
+
+
+def _validate_and_resize_image(
+    img: np.ndarray | Image.Image, max_w: int = MAX_INPUT_WIDTH, max_h: int = MAX_INPUT_HEIGHT
+) -> Image.Image:
+    """Convert and validate image, resize if needed."""
+    # Convert numpy to PIL if needed
+    if isinstance(img, np.ndarray):
+        img = Image.fromarray(img.astype("uint8"))
+    elif not isinstance(img, Image.Image):
+        raise ValueError(f"Invalid image type: {type(img)}")
+
+    # Convert to RGB
+    if img.mode != "RGB":
+        img = img.convert("RGB")
+
+    # Check and resize if too large
+    w, h = img.size
+    if w > max_w or h > max_h:
+        img.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)
+
+    return img
+
+
+def _create_comparison_grid(
+    before_orig: Image.Image, after_orig: Image.Image, before_up: Image.Image, after_up: Image.Image
+) -> Image.Image:
+    """Create a 2x2 grid showing before/after, original/upscaled."""
+
+    # All upscaled images should be the same size (4x original)
+    # For display, we'll scale them down to fit alongside originals
+
+    orig_w, orig_h = before_orig.size
+    up_display_w, up_display_h = before_up.size  # Should be 4x larger
+
+    # Create display versions of upscaled (scaled down slightly for display)
+    display_scale = 0.5  # Show upscaled at 2x (half of 4x)
+    display_w = int(up_display_w * display_scale)
+    display_h = int(up_display_h * display_scale)
+
+    before_up_display = before_up.resize((display_w, display_h), Image.Resampling.LANCZOS)
+    after_up_display = after_up.resize((display_w, display_h), Image.Resampling.LANCZOS)
+
+    # Create grid background
+    grid_w = display_w * 2 + 40  # padding
+    grid_h = display_h * 2 + 80  # padding + title space
+
+    grid = Image.new("RGB", (grid_w, grid_h), color=(30, 30, 30))
+    draw = ImageDraw.Draw(grid)
+
+    # Add labels (simple text, no font to avoid system dependencies)
+    label_y = 10
+    draw.text((10, label_y), "BEFORE (Orig → Upscaled 2x)", fill=(255, 150, 0))
+    draw.text((display_w + 20, label_y), "AFTER (Orig → Upscaled 2x)", fill=(255, 150, 0))
+
+    # Paste images
+    paste_y = 40
+    grid.paste(before_orig, (10, paste_y))
+    grid.paste(before_up_display, (10, paste_y + orig_h + 10))
+
+    grid.paste(after_orig, (display_w + 20, paste_y))
+    grid.paste(after_up_display, (display_w + 20, paste_y + orig_h + 10))
+
+    return grid
+
+
+def upscale_and_compare(
+    before_img, after_img, prompt: str = QUALITY_PROMPT, num_steps: int = 50, guidance: float = 7.5
+) -> tuple[Image.Image, Image.Image, Image.Image, str]:
+    """Upscale before/after pair with detailed comparison.
+
+    Args:
+        before_img: Before image (numpy or PIL)
+        after_img: After image (numpy or PIL)
+        prompt: Quality prompt for upscaling
+        num_steps: Inference steps (20-100)
+        guidance: Guidance scale (1.0-15.0)
+
+    Returns:
+        Tuple of (before_upscaled, after_upscaled, comparison_grid, status_message)
+    """
+
+    if not UPSCALER_AVAILABLE:
+        return (
+            None,
+            None,
+            None,
+            "❌ Upscaler not available. Install: pip install torch diffusers transformers",
+        )  # type: ignore[return-value]
+
+    if before_img is None or after_img is None:
+        return None, None, None, "❌ Please upload both before and after images"  # type: ignore[return-value]
+
+    try:
+        # Get upscaler instance
+        upscaler = _get_upscaler()
+        if upscaler is None:
+            return None, None, None, "❌ Upscaler not available"  # type: ignore[return-value]
+
+        # Validate and resize inputs
+        before_pil = _validate_and_resize_image(before_img)
+        after_pil = _validate_and_resize_image(after_img)
+
+        orig_before_size = before_pil.size
+        orig_after_size = after_pil.size
+
+        # Upscale both images
+        print(f"Upscaling before image ({orig_before_size})...")
+        before_upscaled = upscaler.upscale(
+            before_pil, prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance
+        )
+
+        print(f"Upscaling after image ({orig_after_size})...")
+        after_upscaled = upscaler.upscale(
+            after_pil, prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance
+        )
+
+        # Create comparison grid
+        comparison = _create_comparison_grid(before_pil, after_pil, before_upscaled, after_upscaled)
+
+        # Build status message
+        status = (
+            f"✅ Successfully upscaled both images!\n\n"
+            f"Before: {orig_before_size} → {before_upscaled.size}\n"
+            f"After: {orig_after_size} → {after_upscaled.size}\n"
+            f"Upscale Factor: {UPSCALE_FACTOR}x\n"
+            f"Steps: {num_steps} | Guidance: {guidance}\n"
+            f"\nNote: Comparison shows upscaled images at 2x (50% of 4x for display)"
+        )
+
+        return before_upscaled, after_upscaled, comparison, status
+
+    except Exception as e:
+        error_msg = f"❌ Error during upscaling: {str(e)}"
+        print(error_msg)
+        return None, None, None, error_msg  # type: ignore[return-value]
+
+
+def build_ui() -> None:
+    """Build the advanced upscaling UI tab."""
+
+    gr.Markdown("### Upscale Before/After Images (Max Detail & Clarity)")
+    gr.Markdown(
+        "Upload medical before/after photos to upscale them 4x using Stable Diffusion x4 Upscaler. "
+        "Both images are processed with identical parameters for fair comparison.\n\n"
+        f"⚠️ **Note:** Processing takes 30-60 seconds per image (CPU) or 5-10 seconds (GPU). "
+        f"Maximum input size: {MAX_INPUT_WIDTH}x{MAX_INPUT_HEIGHT}px (automatically resized if larger)."
+    )
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("#### Original Images")
+            before_input = gr.Image(label="Before Image", type="numpy")
+            after_input = gr.Image(label="After Image", type="numpy")
+
+            # Parameters
+            prompt_input = gr.Textbox(
+                label="Quality Prompt",
+                value=QUALITY_PROMPT,
+                placeholder="Describe desired image quality...",
+                lines=3,
+            )
+
+        with gr.Column():
+            gr.Markdown("#### Upscaled Results (4x)")
+            before_output = gr.Image(label="Upscaled Before", type="pil")
+            after_output = gr.Image(label="Upscaled After", type="pil")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            num_steps = gr.Slider(
+                minimum=20, maximum=100, value=50, step=5, label="Inference Steps"
+            )
+        with gr.Column(scale=1):
+            guidance_scale = gr.Slider(
+                minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
+            )
+
+    upscale_btn = gr.Button("🚀 Upscale Both", variant="primary", size="lg")
+    upscale_status = gr.Textbox(label="Status", interactive=False, lines=4)
+
+    gr.Markdown("#### Side-by-Side Comparison")
+    comparison_output = gr.Image(label="Comparison Grid", type="pil")
+
+    # Button click handler
+    upscale_btn.click(
+        fn=upscale_and_compare,
+        inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
+        outputs=[before_output, after_output, comparison_output, upscale_status],
+    )
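
`build_ui` only creates components, so it has to be called inside an active Blocks (or Tab) context; that is also how `app.py` intends to mount it once the import issue is resolved. A minimal hosting sketch, assuming the module import path below:

```python
import gradio as gr

from warp.gradio_app.upscale_compare_tab import build_ui

# Components created inside build_ui() attach to whichever Blocks context
# is active when it runs, so the call must sit inside the `with` stack.
with gr.Blocks(title="Advanced Upscaling") as demo:
    with gr.Tab("Advanced Upscaling"):
        build_ui()

if __name__ == "__main__":
    demo.launch()
```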
warp/inference/__init__.py ADDED
@@ -0,0 +1,10 @@
+"""WARP Inference - AI model inference interfaces."""
+
+from .hf_client import HuggingFaceAPIError, HuggingFaceClient, create_client, infer_image
+
+__all__ = [
+    "HuggingFaceClient",
+    "HuggingFaceAPIError",
+    "create_client",
+    "infer_image",
+]
warp/inference/background_removal.py ADDED
@@ -0,0 +1,226 @@
+"""
+Background Removal - Unified interface for background removal models with performance tracking.
+"""
+
+import time
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+from PIL import Image
+
+from warp.inference.hf_client import HuggingFaceClient
+from warp.inference.local_client import LocalBackgroundRemovalClient
+from warp.inference.metrics import calculate_comprehensive_metrics, calculate_weighted_quality_score
+from warp.models import get_model, list_model_names
+
+
+@dataclass
+class BackgroundRemovalResult:
+    """Result of a background removal operation."""
+
+    output_image: Image.Image
+    model_name: str
+    processing_time_ms: int
+    success: bool = True
+    error_message: Optional[str] = None
+    input_size: tuple = (0, 0)
+    output_size: tuple = (0, 0)
+
+    # Quality metrics
+    edge_quality: Optional[float] = None
+    ssim: Optional[float] = None
+    psnr: Optional[float] = None
+    transparency_coverage: Optional[float] = None
+    mask_accuracy: Optional[float] = None
+    weighted_quality_score: Optional[float] = None
+
+
+class BackgroundRemovalEngine:
+    """Engine for running background removal with multiple models."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        """
+        Initialize background removal engine.
+
+        Args:
+            api_key: Hugging Face API key (optional, defaults to env var)
+        """
+        self.hf_client = None
+        self.local_clients = {}  # Cache local model sessions
+        self.api_key = api_key
+        self.available_models = list_model_names("bg_removal")
+
+    def remove_background(
+        self, image: Image.Image, model_name: Optional[str] = None
+    ) -> BackgroundRemovalResult:
+        """
+        Remove background from an image using the specified model.
+
+        Args:
+            image: Input PIL Image
+            model_name: Model name from registry (uses default if None)
+
+        Returns:
+            BackgroundRemovalResult with output image and metrics
+
+        Raises:
+            ValueError: If model not found
+        """
+        # Get model config
+        if model_name is None:
+            from warp.models import get_default_model
+
+            model_config = get_default_model("bg_removal")
+            if not model_config:
+                raise ValueError("No default background removal model found")
+        else:
+            model_config = get_model(model_name)
+            if not model_config:
+                raise ValueError(f"Model '{model_name}' not found in registry")
+
+        input_size = image.size
+        start_time = time.time()
+
+        try:
+            # Route to appropriate client based on provider
+            if model_config.provider == "local":
+                # Use local rembg
+                if model_config.model_id not in self.local_clients:
+                    self.local_clients[model_config.model_id] = LocalBackgroundRemovalClient(
+                        model_name=model_config.model_id
+                    )
+                client = self.local_clients[model_config.model_id]
+                output_image = client.remove_background(image)
+            else:
+                # Use Hugging Face API
+                if self.hf_client is None:
+                    self.hf_client = HuggingFaceClient(api_key=self.api_key)
+                output_image = self.hf_client.infer_image(
+                    model_id=model_config.model_id,
+                    image=image,
+                    parameters=model_config.default_parameters,
+                )
+
+            processing_time_ms = int((time.time() - start_time) * 1000)
+
+            # Calculate quality metrics
+            metrics = calculate_comprehensive_metrics(
+                output_image=output_image,
+                input_image=image
+            )
+            weighted_score = calculate_weighted_quality_score(metrics)
+
+            return BackgroundRemovalResult(
+                output_image=output_image,
+                model_name=model_config.name,
+                processing_time_ms=processing_time_ms,
+                success=True,
+                input_size=input_size,
+                output_size=output_image.size,
+                edge_quality=metrics.get('edge_quality'),
+                ssim=metrics.get('ssim'),
+                psnr=metrics.get('psnr'),
+                transparency_coverage=metrics.get('transparency_coverage'),
+                mask_accuracy=metrics.get('mask_accuracy'),
+                weighted_quality_score=weighted_score,
+            )
+
+        except Exception as e:
+            processing_time_ms = int((time.time() - start_time) * 1000)
+            return BackgroundRemovalResult(
+                output_image=image,  # Return original on error
+                model_name=model_config.name,
+                processing_time_ms=processing_time_ms,
+                success=False,
+                error_message=str(e),
+                input_size=input_size,
+                output_size=image.size,
+            )
+
+    def compare_models(
+        self, image: Image.Image, model_names: Optional[list] = None
+    ) -> Dict[str, BackgroundRemovalResult]:
+        """
+        Compare multiple background removal models on the same image.
+
+        Args:
+            image: Input PIL Image
+            model_names: List of model names to compare (uses all if None)
+
+        Returns:
+            Dictionary mapping model names to results
+        """
+        if model_names is None:
+            model_names = self.available_models
+
+        results = {}
+        for model_name in model_names:
+            print(f"Processing with {model_name}...")
+            result = self.remove_background(image, model_name)
+            results[model_name] = result
+
+            if result.success:
+                print(f"  ✓ {result.processing_time_ms}ms")
+            else:
+                print(f"  ✗ Error: {result.error_message}")
+
+        return results
+
+
+# ============================================================================
+# Convenience Functions
+# ============================================================================
+
+
+def remove_background(
+    image: Image.Image, model_name: Optional[str] = None, api_key: Optional[str] = None
+) -> Image.Image:
+    """
+    Quick background removal function.
+
+    Args:
+        image: Input PIL Image
+        model_name: Model name (uses default if None)
+        api_key: Optional API key
+
+    Returns:
+        Output PIL Image with background removed
+    """
+    engine = BackgroundRemovalEngine(api_key=api_key)
+    result = engine.remove_background(image, model_name)
+
+    if not result.success:
+        raise RuntimeError(f"Background removal failed: {result.error_message}")
+
+    return result.output_image
+
+
+if __name__ == "__main__":
+    # Demo/test code
+    import os
+
+    print("=== Background Removal Engine ===\n")
+
+    engine = BackgroundRemovalEngine()
+    print(f"Available models: {', '.join(engine.available_models)}\n")
+
+    # Test with a sample image if available
+    test_image_path = "data/scrapedimages/drleedy.com"
+    if os.path.exists(test_image_path):
+        from warp.data import ImageLoader
+
+        loader = ImageLoader()
+        images = loader.get_random_images("drleedy.com", n=1)
+
+        if images:
+            print(f"Testing with image: {images[0]}")
+            test_image = loader.load_image(images[0])
+
+            # Test single model
+            result = engine.remove_background(test_image, model_name="rmbg-1.4")
+            print("\nTest result:")
+            print(f"  Model: {result.model_name}")
+            print(f"  Time: {result.processing_time_ms}ms")
+            print(f"  Success: {result.success}")
+    else:
+        print("No test images found. Run with actual images to test.")
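
Beyond single calls, the engine's `compare_models` returns a dict of results that reduces easily. A sketch, assuming at least one registered model succeeds and that `sample.jpg` is a local test image you supply (the file name is a placeholder):

```python
from PIL import Image

from warp.inference.background_removal import BackgroundRemovalEngine

engine = BackgroundRemovalEngine()
img = Image.open("sample.jpg")  # hypothetical local test image

# Runs every registered bg_removal model when model_names is omitted.
results = engine.compare_models(img)

# Failures come back as results with success=False, so filter before reducing.
ok = {name: r for name, r in results.items() if r.success}
fastest = min(ok, key=lambda name: ok[name].processing_time_ms)
print(f"Fastest successful model: {fastest} ({ok[fastest].processing_time_ms}ms)")
```

Because errors are captured in the result object rather than raised, one broken model (for example, a missing API key for a hosted model) does not abort the whole comparison.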
warp/inference/hf_client.py ADDED
@@ -0,0 +1,253 @@
+"""
+Hugging Face API Client - Unified interface for calling Hugging Face Inference API
+with retry logic, rate limiting, and error handling.
+"""
+
+import os
+import time
+from io import BytesIO
+from typing import Any, Dict, Optional
+
+import requests
+from PIL import Image
+
+# Default configuration
+DEFAULT_API_URL = "https://api-inference.huggingface.co/models"
+DEFAULT_TIMEOUT = 120  # seconds
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_RETRY_DELAY = 2  # seconds
+
+
+class HuggingFaceAPIError(Exception):
+    """Custom exception for Hugging Face API errors."""
+
+    pass
+
+
+class HuggingFaceClient:
+    """Client for Hugging Face Inference API."""
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        max_retries: int = DEFAULT_MAX_RETRIES,
+        retry_delay: int = DEFAULT_RETRY_DELAY,
+        timeout: int = DEFAULT_TIMEOUT,
+    ):
+        """
+        Initialize Hugging Face API client.
+
+        Args:
+            api_key: Hugging Face API key (defaults to HF_API_KEY env var)
+            max_retries: Maximum number of retry attempts
+            retry_delay: Delay between retries in seconds
+            timeout: Request timeout in seconds
+        """
+        self.api_key = api_key or os.getenv("HF_API_KEY")
+        if not self.api_key:
+            raise ValueError(
+                "Hugging Face API key not provided. "
+                "Set HF_API_KEY environment variable or pass api_key parameter."
+            )
+
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self.timeout = timeout
+
+        self.headers = {"Authorization": f"Bearer {self.api_key}"}
+
+    def _build_url(self, model_id: str) -> str:
+        """Build the full API URL for a model."""
+        return f"{DEFAULT_API_URL}/{model_id}"
+
+    def _handle_response(self, response: requests.Response) -> Any:
+        """
+        Handle API response and raise appropriate errors.
+
+        Args:
+            response: Response from API
+
+        Returns:
+            Response content (bytes or dict)
+
+        Raises:
+            HuggingFaceAPIError: If API returns an error
+        """
+        if response.status_code == 200:
+            # Check content type
+            content_type = response.headers.get("content-type", "")
+            if "application/json" in content_type:
+                return response.json()
+            else:
+                return response.content
+
+        # Handle errors
+        error_msg = f"API request failed with status {response.status_code}"
+        try:
+            error_detail = response.json()
+            if "error" in error_detail:
+                error_msg = f"{error_msg}: {error_detail['error']}"
+        except Exception:
+            error_msg = f"{error_msg}: {response.text[:200]}"
+
+        raise HuggingFaceAPIError(error_msg)
+
+    def infer(
+        self,
+        model_id: str,
+        inputs: Any,
+        parameters: Optional[Dict] = None,
+        return_json: bool = False,
+    ) -> Any:
+        """
+        Call Hugging Face Inference API with retry logic.
+
+        Args:
+            model_id: Hugging Face model ID (e.g., 'briaai/RMBG-2.0')
+            inputs: Input data (can be bytes, PIL Image, or dict)
+            parameters: Additional parameters for the model
+            return_json: Whether to expect JSON response
111
+ Returns:
112
+ API response (bytes for images, dict for JSON)
113
+
114
+ Raises:
115
+ HuggingFaceAPIError: If all retries fail
116
+ """
117
+ url = self._build_url(model_id)
118
+ payload = self._prepare_payload(inputs, parameters)
119
+
120
+ for attempt in range(1, self.max_retries + 1):
121
+ try:
122
+ response = requests.post(
123
+ url, headers=self.headers, data=payload, timeout=self.timeout
124
+ )
125
+
126
+ # Handle model loading
127
+ if response.status_code == 503:
128
+ error_data = response.json()
129
+ if "estimated_time" in error_data:
130
+ wait_time = error_data["estimated_time"]
131
+ print(f"Model loading, waiting {wait_time}s...")
132
+ time.sleep(wait_time + 1)
133
+ continue
134
+
135
+ return self._handle_response(response)
136
+
137
+ except requests.exceptions.RequestException as e:
138
+ if attempt == self.max_retries:
139
+ raise HuggingFaceAPIError(f"Request failed after {attempt} attempts: {e}")
140
+
141
+ print(f"Attempt {attempt} failed: {e}. Retrying in {self.retry_delay}s...")
142
+ time.sleep(self.retry_delay)
143
+
144
+ raise HuggingFaceAPIError(f"All {self.max_retries} retries exhausted")
145
+
146
+ def _prepare_payload(self, inputs: Any, parameters: Optional[Dict] = None) -> bytes:
147
+ """
148
+ Prepare payload for API request.
149
+
150
+ Args:
151
+ inputs: Input data (bytes, PIL Image, or dict)
152
+ parameters: Additional parameters
153
+
154
+ Returns:
155
+ Bytes payload
156
+ """
157
+ # Convert PIL Image to bytes
158
+ if isinstance(inputs, Image.Image):
159
+ buffer = BytesIO()
160
+ inputs.save(buffer, format="PNG")
161
+ inputs = buffer.getvalue()
162
+
163
+ # If already bytes, return as-is
164
+ if isinstance(inputs, bytes):
165
+ return inputs
166
+
167
+ # For dict/json inputs, serialize
168
+ if isinstance(inputs, dict):
169
+ import json
170
+
171
+ return json.dumps(inputs).encode("utf-8")
172
+
173
+ raise ValueError(f"Unsupported input type: {type(inputs)}")
174
+
175
+ def infer_image(self, model_id: str, image: Image.Image, parameters: Optional[Dict] = None) -> Image.Image:
176
+ """
177
+ Call inference API with a PIL Image and return a PIL Image.
178
+
179
+ Args:
180
+ model_id: Hugging Face model ID
181
+ image: Input PIL Image
182
+ parameters: Additional parameters
183
+
184
+ Returns:
185
+ Output PIL Image
186
+
187
+ Raises:
188
+ HuggingFaceAPIError: If inference fails
189
+ """
190
+ response_bytes = self.infer(model_id, image, parameters)
191
+ return Image.open(BytesIO(response_bytes))
192
+
193
+ def health_check(self, model_id: str) -> bool:
194
+ """
195
+ Check if a model is available and loaded.
196
+
197
+ Args:
198
+ model_id: Hugging Face model ID
199
+
200
+ Returns:
201
+ True if model is accessible
202
+ """
203
+ try:
204
+ url = self._build_url(model_id)
205
+ response = requests.get(url, headers=self.headers, timeout=10)
206
+ return response.status_code in [200, 503] # 503 means loading
207
+ except Exception:
208
+ return False
209
+
210
+
211
+ # ============================================================================
212
+ # Convenience Functions
213
+ # ============================================================================
214
+
215
+
216
+ def create_client(api_key: Optional[str] = None) -> HuggingFaceClient:
217
+ """Create a HuggingFace client with default settings."""
218
+ return HuggingFaceClient(api_key=api_key)
219
+
220
+
221
+ def infer_image(
222
+ model_id: str, image: Image.Image, api_key: Optional[str] = None, parameters: Optional[Dict] = None
223
+ ) -> Image.Image:
224
+ """
225
+ One-shot image inference with a Hugging Face model.
226
+
227
+ Args:
228
+ model_id: Hugging Face model ID
229
+ image: Input PIL Image
230
+ api_key: Optional API key (defaults to env var)
231
+ parameters: Additional parameters
232
+
233
+ Returns:
234
+ Output PIL Image
235
+ """
236
+ client = create_client(api_key)
237
+ return client.infer_image(model_id, image, parameters)
238
+
239
+
240
+ if __name__ == "__main__":
241
+ # Demo/test code
242
+ client = create_client()
243
+
244
+ # Test model availability
245
+ test_model = "briaai/RMBG-1.4"
246
+ print(f"Testing connection to {test_model}...")
247
+ is_available = client.health_check(test_model)
248
+ print(f"Model available: {is_available}")
249
+
250
+ if is_available:
251
+ print("\n✓ Hugging Face API client ready")
252
+ else:
253
+ print("\n✗ Failed to connect to model")
warp/inference/local_client.py ADDED
@@ -0,0 +1,64 @@
+ """
+ Local Background Removal Client - Uses rembg for local processing.
+ """
+
+ from typing import Optional
+ from PIL import Image
+ from rembg import remove, new_session
+
+
+ class LocalBackgroundRemovalClient:
+     """Client for local background removal using rembg."""
+
+     # Available models in rembg
+     MODELS = {
+         "u2net": "General purpose model (fast)",
+         "u2netp": "Lightweight version of u2net",
+         "u2net_human_seg": "Optimized for human segmentation",
+         "u2net_cloth_seg": "Optimized for cloth segmentation",
+         "silueta": "General segmentation",
+         "isnet-general-use": "Improved segmentation (recommended)",
+         "isnet-anime": "Anime/illustration specific",
+     }
+
+     def __init__(self, model_name: str = "isnet-general-use"):
+         """
+         Initialize local client with a specific model.
+
+         Args:
+             model_name: Name of the rembg model to use
+         """
+         self.model_name = model_name
+         self.session = None
+
+     def _get_session(self):
+         """Lazy-load the model session."""
+         if self.session is None:
+             self.session = new_session(self.model_name)
+         return self.session
+
+     def remove_background(self, image: Image.Image) -> Image.Image:
+         """
+         Remove background from an image.
+
+         Args:
+             image: Input PIL Image
+
+         Returns:
+             PIL Image with background removed (RGBA)
+         """
+         session = self._get_session()
+         output = remove(image, session=session)
+         return output
+
+     @classmethod
+     def list_models(cls):
+         """List available models."""
+         return cls.MODELS
+
+
+ if __name__ == "__main__":
+     # Test
+     print("Available rembg models:")
+     for name, desc in LocalBackgroundRemovalClient.MODELS.items():
+         print(f"  - {name}: {desc}")
warp/inference/metrics.py ADDED
@@ -0,0 +1,256 @@
+ """
+ Quality Metrics - Calculate image quality metrics for background removal evaluation.
+ """
+
+ import numpy as np
+ from PIL import Image
+ from typing import Dict, Tuple, Optional
+ import cv2
+
+
+ def calculate_edge_quality(image: Image.Image, mask: Optional[Image.Image] = None) -> float:
+     """
+     Calculate edge quality score based on gradient strength.
+
+     Higher score = sharper edges.
+
+     Args:
+         image: PIL Image (RGBA for background removal)
+         mask: Optional mask (if available separately)
+
+     Returns:
+         Edge quality score (0.0 to 1.0)
+     """
+     # Convert to numpy array
+     if image.mode == 'RGBA':
+         # Use alpha channel as mask
+         img_array = np.array(image)
+         mask_array = img_array[:, :, 3]
+     elif mask is not None:
+         mask_array = np.array(mask.convert('L'))
+     else:
+         # Convert to grayscale if no alpha
+         mask_array = np.array(image.convert('L'))
+
+     # Calculate gradients using Sobel
+     grad_x = cv2.Sobel(mask_array, cv2.CV_64F, 1, 0, ksize=3)
+     grad_y = cv2.Sobel(mask_array, cv2.CV_64F, 0, 1, ksize=3)
+
+     # Gradient magnitude
+     gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
+
+     # Normalize to 0-1
+     if gradient_magnitude.max() > 0:
+         edge_score = np.mean(gradient_magnitude) / gradient_magnitude.max()
+     else:
+         edge_score = 0.0
+
+     return float(edge_score)
+
+
+ def calculate_mask_accuracy(pred_mask: Image.Image, gt_mask: Image.Image) -> float:
+     """
+     Calculate mask accuracy (IoU - Intersection over Union).
+
+     Only applicable when a ground truth mask is available.
+
+     Args:
+         pred_mask: Predicted mask (alpha channel or grayscale)
+         gt_mask: Ground truth mask
+
+     Returns:
+         IoU score (0.0 to 1.0)
+     """
+     # Convert to binary masks
+     pred_array = np.array(pred_mask.convert('L')) > 127
+     gt_array = np.array(gt_mask.convert('L')) > 127
+
+     # Calculate IoU
+     intersection = np.logical_and(pred_array, gt_array).sum()
+     union = np.logical_or(pred_array, gt_array).sum()
+
+     if union == 0:
+         return 0.0
+
+     iou = intersection / union
+     return float(iou)
+
+
+ def calculate_ssim(image1: Image.Image, image2: Image.Image) -> float:
+     """
+     Calculate Structural Similarity Index (SSIM) between two images.
+
+     Used to compare a processed image with a reference.
+
+     Args:
+         image1: First image
+         image2: Second image
+
+     Returns:
+         SSIM score (-1.0 to 1.0, higher is better)
+     """
+     from skimage.metrics import structural_similarity as ssim
+
+     # Convert to grayscale numpy arrays
+     img1_gray = np.array(image1.convert('L'))
+     img2_gray = np.array(image2.convert('L'))
+
+     # Ensure same size
+     if img1_gray.shape != img2_gray.shape:
+         # Resize to match
+         img2_gray = cv2.resize(img2_gray, (img1_gray.shape[1], img1_gray.shape[0]))
+
+     # Calculate SSIM
+     score = ssim(img1_gray, img2_gray)
+     return float(score)
+
+
+ def calculate_psnr(image1: Image.Image, image2: Image.Image) -> float:
+     """
+     Calculate Peak Signal-to-Noise Ratio (PSNR) between two images.
+
+     Higher PSNR = better quality (less noise/distortion).
+
+     Args:
+         image1: First image
+         image2: Second image
+
+     Returns:
+         PSNR in dB (higher is better, typically 20-50)
+     """
+     # Convert to numpy arrays
+     img1_array = np.array(image1.convert('RGB')).astype(float)
+     img2_array = np.array(image2.convert('RGB')).astype(float)
+
+     # Ensure same size
+     if img1_array.shape != img2_array.shape:
+         img2_array = cv2.resize(img2_array, (img1_array.shape[1], img1_array.shape[0]))
+
+     # Calculate MSE
+     mse = np.mean((img1_array - img2_array) ** 2)
+
+     if mse == 0:
+         return 100.0  # Perfect match
+
+     # Calculate PSNR
+     max_pixel = 255.0
+     psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
+
+     return float(psnr)
+
+
+ def calculate_transparency_coverage(image: Image.Image) -> float:
+     """
+     Calculate the fraction of an image that is transparent (for RGBA images).
+
+     Useful for background removal evaluation.
+
+     Args:
+         image: RGBA PIL Image
+
+     Returns:
+         Transparency coverage (0.0 to 1.0)
+     """
+     if image.mode != 'RGBA':
+         return 0.0
+
+     img_array = np.array(image)
+     alpha_channel = img_array[:, :, 3]
+
+     # Count transparent pixels (alpha < 10)
+     transparent_pixels = (alpha_channel < 10).sum()
+     total_pixels = alpha_channel.size
+
+     coverage = transparent_pixels / total_pixels
+     return float(coverage)
+
+
+ def calculate_comprehensive_metrics(
+     output_image: Image.Image,
+     input_image: Optional[Image.Image] = None,
+     ground_truth_mask: Optional[Image.Image] = None
+ ) -> Dict[str, float]:
+     """
+     Calculate all available quality metrics for an image.
+
+     Args:
+         output_image: Processed image (typically RGBA after bg removal)
+         input_image: Original input image (for SSIM/PSNR comparison)
+         ground_truth_mask: Ground truth mask (if available, for accuracy)
+
+     Returns:
+         Dictionary of metric names to scores
+     """
+     metrics = {}
+
+     # Edge quality (always available)
+     metrics['edge_quality'] = calculate_edge_quality(output_image)
+
+     # Transparency coverage (for RGBA images)
+     if output_image.mode == 'RGBA':
+         metrics['transparency_coverage'] = calculate_transparency_coverage(output_image)
+
+     # Comparison metrics (if input provided)
+     if input_image is not None:
+         try:
+             metrics['ssim'] = calculate_ssim(input_image, output_image)
+         except Exception as e:
+             print(f"SSIM calculation failed: {e}")
+             metrics['ssim'] = 0.0
+
+         try:
+             metrics['psnr'] = calculate_psnr(input_image, output_image)
+         except Exception as e:
+             print(f"PSNR calculation failed: {e}")
+             metrics['psnr'] = 0.0
+
+     # Mask accuracy (if ground truth provided)
+     if ground_truth_mask is not None and output_image.mode == 'RGBA':
+         try:
+             # Extract alpha channel as predicted mask
+             alpha = output_image.split()[3]
+             metrics['mask_accuracy'] = calculate_mask_accuracy(alpha, ground_truth_mask)
+         except Exception as e:
+             print(f"Mask accuracy calculation failed: {e}")
+             metrics['mask_accuracy'] = 0.0
+
+     return metrics
+
+
+ def calculate_weighted_quality_score(metrics: Dict[str, float]) -> float:
+     """
+     Calculate an overall weighted quality score from individual metrics.
+
+     Formula: mask_accuracy × 0.6 + edge_quality × 0.4
+
+     If mask_accuracy is unavailable, falls back to edge_quality alone.
+
+     Args:
+         metrics: Dictionary of metric scores
+
+     Returns:
+         Weighted quality score (0.0 to 1.0)
+     """
+     edge_quality = metrics.get('edge_quality', 0.0)
+     mask_accuracy = metrics.get('mask_accuracy')
+
+     if mask_accuracy is not None:
+         # Use weighted combination
+         score = (mask_accuracy * 0.6 + edge_quality * 0.4)
+     else:
+         # Use edge quality only
+         score = edge_quality
+
+     return float(score)
+
+
+ if __name__ == "__main__":
+     # Demo
+     print("=== Quality Metrics Module ===\n")
+     print("Available metrics:")
+     print("  - Edge Quality: Gradient-based edge sharpness")
+     print("  - Mask Accuracy: IoU with ground truth (if available)")
+     print("  - SSIM: Structural similarity (0-1)")
+     print("  - PSNR: Peak signal-to-noise ratio (dB)")
+     print("  - Transparency Coverage: % of transparent pixels")
+     print("  - Weighted Quality Score: Combined metric for ranking")
warp/inference/upscaler.py ADDED
@@ -0,0 +1,117 @@
+ """
+ Upscaler Interface - Unified interface for image upscaling models.
+
+ Placeholder for Phase 3 upscaling integration.
+ """
+
+ import time
+ from dataclasses import dataclass
+ from typing import Optional
+ from PIL import Image
+
+ from warp.models import get_model
+
+
+ @dataclass
+ class UpscaleResult:
+     """Result of an upscaling operation."""
+
+     output_image: Image.Image
+     model_name: str
+     processing_time_ms: int
+     success: bool = True
+     error_message: Optional[str] = None
+     input_size: tuple = (0, 0)
+     output_size: tuple = (0, 0)
+     scale_factor: int = 4
+
+
+ class UpscalerEngine:
+     """Engine for running upscaling models."""
+
+     def __init__(self, api_key: Optional[str] = None):
+         """
+         Initialize upscaler engine.
+
+         Args:
+             api_key: API key for external services (if needed)
+         """
+         self.api_key = api_key
+         # TODO: Initialize clients when upscaling models are added
+
+     def upscale(
+         self,
+         image: Image.Image,
+         model_name: Optional[str] = None,
+         scale_factor: int = 4
+     ) -> UpscaleResult:
+         """
+         Upscale an image using the specified model.
+
+         Args:
+             image: Input PIL Image
+             model_name: Model name from registry (uses default if None)
+             scale_factor: Upscaling factor (2x, 4x, etc.)
+
+         Returns:
+             UpscaleResult with output image and metrics
+
+         Raises:
+             ValueError: If model not found
+             NotImplementedError: If upscaling not yet implemented
+         """
+         # Get model config
+         if model_name is None:
+             from warp.models import get_default_model
+
+             model_config = get_default_model("upscale")
+             if not model_config:
+                 raise ValueError("No default upscaling model found")
+         else:
+             model_config = get_model(model_name)
+             if not model_config:
+                 raise ValueError(f"Model '{model_name}' not found in registry")
+
+         # TODO: Implement actual upscaling in Phase 3
+         raise NotImplementedError(
+             f"Upscaling with {model_config.name} not yet implemented. "
+             "This will be added in Phase 3."
+         )
+
+
+ # Convenience function
+ def upscale_image(
+     image: Image.Image,
+     model_name: Optional[str] = None,
+     scale_factor: int = 4,
+     api_key: Optional[str] = None
+ ) -> Image.Image:
+     """
+     Quick upscaling function.
+
+     Args:
+         image: Input PIL Image
+         model_name: Model name (uses default if None)
+         scale_factor: Upscaling factor
+         api_key: Optional API key
+
+     Returns:
+         Upscaled PIL Image
+     """
+     engine = UpscalerEngine(api_key=api_key)
+     result = engine.upscale(image, model_name, scale_factor)
+
+     if not result.success:
+         raise RuntimeError(f"Upscaling failed: {result.error_message}")
+
+     return result.output_image
+
+
+ if __name__ == "__main__":
+     print("=== Upscaler Module ===\n")
+     print("Status: Placeholder for Phase 3")
+     print("\nPlanned upscaling models:")
+     print("  - Real-ESRGAN 4x")
+     print("  - GFPGAN (face restoration)")
+     print("  - CodeFormer (face enhancement)")
+     print("  - Swin2SR")
+     print("\nUpscaling will be integrated in Phase 3 (Week 4-6)")
warp/models/__init__.py ADDED
@@ -0,0 +1,25 @@
+ """WARP Models - Model registry and configuration."""
+
+ from .registry import (
+     BACKGROUND_REMOVAL_MODELS,
+     MODEL_REGISTRY,
+     UPSCALING_MODELS,
+     ModelConfig,
+     get_default_model,
+     get_model,
+     get_models_by_type,
+     list_all_models,
+     list_model_names,
+ )
+
+ __all__ = [
+     "ModelConfig",
+     "MODEL_REGISTRY",
+     "BACKGROUND_REMOVAL_MODELS",
+     "UPSCALING_MODELS",
+     "get_model",
+     "get_models_by_type",
+     "get_default_model",
+     "list_all_models",
+     "list_model_names",
+ ]
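With these re-exports, callers can stay off the registry module itself; a quick sketch:

    from warp.models import get_default_model, get_models_by_type

    default = get_default_model("bg_removal")
    if default:
        print(f"Default: {default.display_name}")
    print(f"Upscalers: {sorted(get_models_by_type('upscale'))}")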
warp/models/registry.py ADDED
@@ -0,0 +1,248 @@
+ """
+ Model Registry - Centralized configuration for all AI models in WARP.
+ """
+
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional
+
+
+ @dataclass
+ class ModelConfig:
+     """Configuration for an AI model."""
+
+     name: str
+     display_name: str
+     model_id: str  # Hugging Face model ID or path
+     provider: str = "huggingface"  # 'huggingface', 'replicate', 'local'
+     operation_type: str = "bg_removal"  # 'bg_removal', 'upscale', 'color_correct'
+     description: str = ""
+     version: Optional[str] = None
+
+     # Model capabilities
+     input_formats: List[str] = field(default_factory=lambda: ["png", "jpg", "jpeg", "webp"])
+     output_format: str = "png"
+     max_input_size: int = 2048  # Max dimension in pixels
+     requires_gpu: bool = False
+
+     # Default parameters
+     default_parameters: Dict = field(default_factory=dict)
+
+     # Performance hints
+     estimated_time_ms: int = 3000
+     is_default: bool = False
+     is_active: bool = True
+
+
+ # ============================================================================
+ # Background Removal Models
+ # ============================================================================
+
+ BACKGROUND_REMOVAL_MODELS = {
+     # Local models (using rembg)
+     "isnet-general": ModelConfig(
+         name="isnet-general",
+         display_name="ISNet General (Local)",
+         model_id="isnet-general-use",
+         provider="local",
+         operation_type="bg_removal",
+         description="Improved segmentation model - highest quality, slower",
+         estimated_time_ms=678,  # Updated from benchmark results
+         is_default=False,
+         default_parameters={},
+     ),
+     "u2net": ModelConfig(
+         name="u2net",
+         display_name="U2Net (Local)",
+         model_id="u2net",
+         provider="local",
+         operation_type="bg_removal",
+         description="General purpose background removal - BEST SPEED/QUALITY BALANCE",
+         estimated_time_ms=416,  # Updated from benchmark results
+         is_default=True,  # Updated based on Phase 2 benchmark results
+         default_parameters={},
+     ),
+     "u2net-human": ModelConfig(
+         name="u2net-human",
+         display_name="U2Net Human (Local)",
+         model_id="u2net_human_seg",
+         provider="local",
+         operation_type="bg_removal",
+         description="Optimized for human segmentation and portraits",
+         estimated_time_ms=436,  # Updated from benchmark results
+         is_default=False,  # Changed: u2net outperformed in benchmarks
+         default_parameters={},
+     ),
+     "isnet-anime": ModelConfig(
+         name="isnet-anime",
+         display_name="ISNet Anime (Local)",
+         model_id="isnet-anime",
+         provider="local",
+         operation_type="bg_removal",
+         description="Specialized for anime/illustrations - not suitable for medical photos",
+         estimated_time_ms=727,  # Updated from benchmark results
+         is_default=False,
+         default_parameters={},
+     ),
+     # HuggingFace models (currently unavailable)
+     "rmbg-1.4": ModelConfig(
+         name="rmbg-1.4",
+         display_name="RMBG 1.4 (HF)",
+         model_id="briaai/RMBG-1.4",
+         operation_type="bg_removal",
+         description="Fast and accurate background removal by Bria AI (v1.4) - Currently unavailable",
+         version="1.4",
+         estimated_time_ms=2000,
+         is_default=False,
+         is_active=False,  # Disabled due to HF API deprecation
+         default_parameters={},
+     ),
+ }
+
+ # ============================================================================
+ # Upscaling Models
+ # ============================================================================
+
+ UPSCALING_MODELS = {
+     "realesrgan": ModelConfig(
+         name="realesrgan",
+         display_name="Real-ESRGAN 4x",
+         model_id="ai-forever/Real-ESRGAN",
+         operation_type="upscale",
+         description="Real-ESRGAN 4x upscaling for general images",
+         estimated_time_ms=8000,
+         is_default=True,
+         requires_gpu=True,
+         default_parameters={"scale": 4},
+     ),
+     "gfpgan": ModelConfig(
+         name="gfpgan",
+         display_name="GFPGAN",
+         model_id="TencentARC/GFPGAN",
+         operation_type="upscale",
+         description="Face restoration and enhancement",
+         estimated_time_ms=7000,
+         is_default=False,
+         requires_gpu=True,
+         default_parameters={"version": "1.3"},
+     ),
+     "codeformer": ModelConfig(
+         name="codeformer",
+         display_name="CodeFormer",
+         model_id="sczhou/CodeFormer",
+         operation_type="upscale",
+         description="Face restoration with better quality preservation",
+         estimated_time_ms=9000,
+         is_default=False,
+         requires_gpu=True,
+         default_parameters={"w": 0.5},
+     ),
+     "swin2sr": ModelConfig(
+         name="swin2sr",
+         display_name="Swin2SR",
+         model_id="caidas/swin2SR-classical-sr-x4-64",
+         operation_type="upscale",
+         description="Swin Transformer for image super-resolution",
+         estimated_time_ms=6000,
+         is_default=False,
+         requires_gpu=True,
+         default_parameters={"scale": 4},
+     ),
+ }
+
+ # ============================================================================
+ # Combined Registry
+ # ============================================================================
+
+ MODEL_REGISTRY = {
+     **BACKGROUND_REMOVAL_MODELS,
+     **UPSCALING_MODELS,
+ }
+
+
+ # ============================================================================
+ # Helper Functions
+ # ============================================================================
+
+
+ def get_model(name: str) -> Optional[ModelConfig]:
+     """Get model configuration by name."""
+     return MODEL_REGISTRY.get(name)
+
+
+ def get_models_by_type(operation_type: str) -> Dict[str, ModelConfig]:
+     """Get all models for a specific operation type."""
+     return {
+         name: config
+         for name, config in MODEL_REGISTRY.items()
+         if config.operation_type == operation_type and config.is_active
+     }
+
+
+ def get_default_model(operation_type: str) -> Optional[ModelConfig]:
+     """Get the default model for an operation type."""
+     models = get_models_by_type(operation_type)
+     for model in models.values():
+         if model.is_default:
+             return model
+     # If no default set, return first active model
+     return next(iter(models.values())) if models else None
+
+
+ def list_all_models() -> Dict[str, ModelConfig]:
+     """Get all active models."""
+     return {name: config for name, config in MODEL_REGISTRY.items() if config.is_active}
+
+
+ def list_model_names(operation_type: Optional[str] = None) -> List[str]:
+     """Get list of model names, optionally filtered by operation type (the unfiltered list includes inactive models)."""
+     if operation_type:
+         return list(get_models_by_type(operation_type).keys())
+     return list(MODEL_REGISTRY.keys())
+
+
+ # ============================================================================
+ # Model Information Display
+ # ============================================================================
+
+
+ def get_model_info(name: str) -> str:
+     """Get formatted information about a model."""
+     model = get_model(name)
+     if not model:
+         return f"Model '{name}' not found in registry."
+
+     info = f"""
+ Model: {model.display_name}
+ ID: {model.model_id}
+ Type: {model.operation_type}
+ Provider: {model.provider}
+ Description: {model.description}
+ Est. Time: {model.estimated_time_ms}ms
+ Default: {'Yes' if model.is_default else 'No'}
+ Active: {'Yes' if model.is_active else 'No'}
+ """.strip()
+
+     return info
+
+
+ if __name__ == "__main__":
+     # Demo/test code
+     print("=== WARP Model Registry ===\n")
+
+     print("Background Removal Models:")
+     for name in list_model_names("bg_removal"):
+         print(f"  - {name}")
+
+     print("\nUpscaling Models:")
+     for name in list_model_names("upscale"):
+         print(f"  - {name}")
+
+     print("\nDefault Background Removal Model:")
+     default_bg = get_default_model("bg_removal")
+     if default_bg:
+         print(f"  {default_bg.display_name} ({default_bg.name})")
+
+     print("\nDefault Upscaling Model:")
+     default_up = get_default_model("upscale")
+     if default_up:
+         print(f"  {default_up.display_name} ({default_up.name})")