prithivMLmods committed
Commit fdd9313 · verified · 1 Parent(s): ee4b2e1

Update app.py

Files changed (1): app.py (+30, -19)
app.py CHANGED
@@ -1,4 +1,10 @@
 import os
+import gc
+
+# 1. FIX: Set memory allocation configuration BEFORE importing torch
+# 'expandable_segments:True' prevents the specific CUDACachingAllocator assertion failure
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
 import gradio as gr
 import numpy as np
 import spaces
@@ -9,7 +15,7 @@ from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
-# --- Theme Configuration ---
+# Define Theme
 colors.orange_red = colors.Color(
     name="orange_red",
     c50="#FFF0E5",
@@ -78,7 +84,6 @@ class OrangeRedTheme(Soft):
 
 orange_red_theme = OrangeRedTheme()
 
-# --- Device Setup ---
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 dtype = torch.bfloat16
 
@@ -87,7 +92,6 @@ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
 from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
 from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
 
-# --- Model Loading ---
 print("Loading Qwen Image Edit Pipeline...")
 pipe = QwenImageEditPlusPipeline.from_pretrained(
     "Qwen/Qwen-Image-Edit-2509",
@@ -100,6 +104,13 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
     torch_dtype=dtype
 ).to(device)
 
+# 2. FIX: Enable VAE Tiling. This is crucial for decoding large images without OOM.
+try:
+    pipe.enable_vae_tiling()
+    print("VAE Tiling enabled.")
+except Exception as e:
+    print(f"Warning: Could not enable VAE tiling: {e}")
+
 print("Loading and Fusing Lightning LoRA...")
 pipe.load_lora_weights("lightx2v/Qwen-Image-Lightning",
     weight_name="Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
@@ -173,6 +184,10 @@ def infer(
     steps,
     progress=gr.Progress(track_tqdm=True)
 ):
+    # 3. FIX: Aggressive Garbage Collection before run
+    gc.collect()
+    torch.cuda.empty_cache()
+
     if image_1 is None or image_2 is None:
         raise gr.Error("Please upload both images for Fusion/Texture/FaceSwap tasks.")
 
@@ -217,13 +232,9 @@
 
     width, height = update_dimensions_on_upload(img1_pil)
 
-    # --- Fix: Explicit Memory Management ---
-    # Clear cache before starting the heavy inference process
-    torch.cuda.empty_cache()
-
     try:
-        # Use no_grad to prevent gradient calculation and save memory
-        with torch.no_grad():
+        # 3. FIX: Use inference_mode for better memory efficiency
+        with torch.inference_mode():
             result = pipe(
                 image=[img1_pil, img2_pil],
                 prompt=prompt,
@@ -234,24 +245,24 @@
                 generator=generator,
                 true_cfg_scale=guidance_scale,
             ).images[0]
+
+        return result, seed
+
     except Exception as e:
-        # If an error occurs, ensure we still clear cache before raising
-        torch.cuda.empty_cache()
+        # Rethrow so Gradio sees the error, but allow finally block to run
        raise e
-
-    # Clear cache after inference is done
-    torch.cuda.empty_cache()
-
-    return result, seed
+
+    finally:
+        # 3. FIX: Cleanup after run regardless of success or failure
+        gc.collect()
+        torch.cuda.empty_cache()
 
 @spaces.GPU
 def infer_example(image_1, image_2, prompt, lora_adapter):
     if image_1 is None or image_2 is None:
         return None, 0
 
-    # Optional: Clear cache before example inference as well
-    torch.cuda.empty_cache()
-
+    # Simple wrapper call
     result, seed = infer(
         image_1.convert("RGB"),
         image_2.convert("RGB"),
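A note on fix 1: PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's CUDA caching allocator initializes, which happens lazily at the first CUDA allocation, so setting the variable before `import torch` is the conservative way to guarantee it lands in time. A minimal standalone sketch of the ordering (the probe tensor and `memory_summary` call are illustrative additions, not part of the commit):

```python
import os

# Must be in place before the CUDA caching allocator initializes, i.e.
# before the first CUDA allocation -- setting it before `import torch`
# is the safe ordering.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

if torch.cuda.is_available():
    _ = torch.zeros(1, device="cuda")  # first allocation: the config takes effect here
    print(torch.cuda.memory_summary(abbreviated=True))
```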
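Fix 2 trades a little decode speed for peak memory: with tiling enabled, the VAE decodes the latents in overlapping tiles instead of materializing the full-resolution image's activations at once. The try/except guard exists because `enable_vae_tiling()` is a diffusers pipeline convention that a custom pipeline such as QwenImageEditPlusPipeline may not implement. The same guard can be written with an explicit `hasattr` check, as in this sketch (the helper name is hypothetical):

```python
def enable_vae_tiling_if_supported(pipe) -> bool:
    """Enable tiled VAE decoding when the pipeline supports it.

    Hypothetical helper: `pipe` stands in for any diffusers-style pipeline.
    """
    if not hasattr(pipe, "enable_vae_tiling"):
        return False
    try:
        pipe.enable_vae_tiling()
        return True
    except Exception as e:
        # Keep the app booting even when tiling cannot be enabled.
        print(f"Warning: Could not enable VAE tiling: {e}")
        return False
```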
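For fix 3, `torch.inference_mode()` does everything `torch.no_grad()` does and additionally marks tensors created inside the block as inference tensors, skipping autograd's version-counter and view bookkeeping. The cost is flexibility: inference tensors cannot participate in autograd later, which is irrelevant in a pure-inference handler like this one. A quick illustration:

```python
import torch

with torch.no_grad():
    a = torch.ones(3)

with torch.inference_mode():
    b = torch.ones(3)

print(a.is_inference())  # False: ordinary tensor, gradient recording merely off
print(b.is_inference())  # True: exempt from autograd tracking entirely
# In-place updates to an inference tensor outside inference mode raise a
# RuntimeError, so keep such tensors inside the inference path.
```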
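Fix 3's cleanup discipline pairs `gc.collect()` with `torch.cuda.empty_cache()`: empty_cache can only return cached blocks that Python no longer references, so collecting first matters, and the `finally` block makes the cleanup run whether inference succeeds or raises. Since the same before/after pattern now brackets the handler, it could be factored into a decorator; a sketch under that assumption (`cuda_cleanup` is hypothetical, not part of the commit):

```python
import functools
import gc

import torch

def cuda_cleanup(fn):
    """Hypothetical decorator mirroring the commit's before/after cleanup."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            return fn(*args, **kwargs)
        finally:
            # Runs on success and on error, like the commit's finally block.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return wrapper
```

Applied as `@cuda_cleanup` above `infer`, this would replace the inline calls without changing behavior.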