ginipick committed
Commit ab208ea · verified · 1 Parent(s): 4630bc1

Update app.py

Files changed (1)
  1. app.py +32 -428
app.py CHANGED
@@ -1,431 +1,35 @@
- import gradio as gr
- import spaces
- import torch
- import diffusers
- import transformers
- import copy
- import random
- import numpy as np
- import torchvision.transforms as T
- import math
  import os
- import peft
- from peft import LoraConfig
- from safetensors import safe_open
- from omegaconf import OmegaConf
- from omnitry.models.transformer_flux import FluxTransformer2DModel
- from omnitry.pipelines.pipeline_flux_fill import FluxFillPipeline
- from PIL import Image
-
- from huggingface_hub import snapshot_download
- snapshot_download(repo_id="Kunbyte/OmniTry", local_dir="./OmniTry")
-
- device = torch.device('cuda:0')
- weight_dtype = torch.bfloat16
- args = OmegaConf.load('configs/omnitry_v1_unified.yaml')
-
- # init model
- transformer = FluxTransformer2DModel.from_pretrained(
-     'black-forest-labs/FLUX.1-Fill-dev',
-     subfolder='transformer'
- ).requires_grad_(False).to(device, dtype=weight_dtype)
-
- pipeline = FluxFillPipeline.from_pretrained(
-     'black-forest-labs/FLUX.1-Fill-dev',
-     transformer=transformer,
-     torch_dtype=weight_dtype
- ).to(device)
-
- # insert LoRA
- lora_config = LoraConfig(
-     r=args.lora_rank,
-     lora_alpha=args.lora_alpha,
-     init_lora_weights="gaussian",
-     target_modules=[
-         'x_embedder',
-         'attn.to_k', 'attn.to_q', 'attn.to_v', 'attn.to_out.0',
-         'attn.add_k_proj', 'attn.add_q_proj', 'attn.add_v_proj', 'attn.to_add_out',
-         'ff.net.0.proj', 'ff.net.2', 'ff_context.net.0.proj', 'ff_context.net.2',
-         'norm1_context.linear', 'norm1.linear', 'norm.linear', 'proj_mlp', 'proj_out'
-     ]
- )
- transformer.add_adapter(lora_config, adapter_name='vtryon_lora')
- transformer.add_adapter(lora_config, adapter_name='garment_lora')
-
- with safe_open('OmniTry/omnitry_v1_unified.safetensors', framework="pt") as f:
-     lora_weights = {k: f.get_tensor(k) for k in f.keys()}
-     transformer.load_state_dict(lora_weights, strict=False)
-
- # hack lora forward
- def create_hacked_forward(module):
-
-     def lora_forward(self, active_adapter, x, *args, **kwargs):
-         result = self.base_layer(x, *args, **kwargs)
-         if active_adapter is not None:
-             lora_A = self.lora_A[active_adapter]
-             lora_B = self.lora_B[active_adapter]
-             dropout = self.lora_dropout[active_adapter]
-             scaling = self.scaling[active_adapter]
-             x = x.to(lora_A.weight.dtype)
-             result = result + lora_B(lora_A(dropout(x))) * scaling
-         return result
-
-     def hacked_lora_forward(self, x, *args, **kwargs):
-         # split the batch: the person image goes through 'vtryon_lora',
-         # the garment image through 'garment_lora'
-         return torch.cat((
-             lora_forward(self, 'vtryon_lora', x[:1], *args, **kwargs),
-             lora_forward(self, 'garment_lora', x[1:], *args, **kwargs),
-         ), dim=0)
-
-     return hacked_lora_forward.__get__(module, type(module))
-
- for n, m in transformer.named_modules():
-     if isinstance(m, peft.tuners.lora.layer.Linear):
-         m.forward = create_hacked_forward(m)
-
-
- def seed_everything(seed=0):
-     random.seed(seed)
-     os.environ['PYTHONHASHSEED'] = str(seed)
-     np.random.seed(seed)
-     torch.manual_seed(seed)
-     torch.cuda.manual_seed(seed)
-     torch.cuda.manual_seed_all(seed)
-
-
- # Category mapping with sample images
- CATEGORY_SAMPLES = {
-     'top': 'top.png',
-     'bottom': 'bottom.png',
-     'dress': 'dress.jpg',
-     'hat': 'hat.png',
-     'sunglasses': 'sunglasses.png',
-     'glasses': 'glasses.png',
-     'necklace': 'necklace.png',
-     'earrings': 'earrings.png',
-     'bracelet': 'bracelet.png',
-     'ring': 'ring.png',
-     'tie': 'tie.png',
-     'bow tie': 'bow tie.png',
-     'belt': 'belt.png',
-     'shoe': 'shoe.png',
-     'bag': 'bag.png'
- }
-
- # Person sample images
- PERSON_SAMPLES = ['woman.png', 'man.png']
-
- def load_sample_image(category):
-     """Load sample image for selected category"""
-     if category and category in CATEGORY_SAMPLES:
-         img_path = CATEGORY_SAMPLES[category]
-         if os.path.exists(img_path):
-             return Image.open(img_path)
-     return None
-
- def load_random_person():
-     """Load random person image on initialization"""
-     person_img = random.choice(PERSON_SAMPLES)
-     if os.path.exists(person_img):
-         return Image.open(person_img)
-     return None
-
- def create_category_html():
-     """Create HTML for category thumbnails"""
-     html = '<div class="category-container">'
-     for category, img_file in CATEGORY_SAMPLES.items():
-         if os.path.exists(img_file):
-             # Create base64 encoded thumbnail for each category
-             import base64
-             from io import BytesIO
-             img = Image.open(img_file)
-             # Resize for thumbnail
-             img.thumbnail((80, 80), Image.Resampling.LANCZOS)
-             buffered = BytesIO()
-             img.save(buffered, format="PNG")
-             img_str = base64.b64encode(buffered.getvalue()).decode()
-
-             html += f'''
-             <div class="category-item" data-category="{category}">
-                 <img src="data:image/png;base64,{img_str}" alt="{category}">
-                 <span>{category.title()}</span>
-             </div>
-             '''
-     html += '</div>'
-     return html
-
-
- @spaces.GPU
- def generate(person_image, object_image, object_class, steps, guidance_scale, seed):
-     if seed == -1:
-         seed = random.randint(0, 2**32 - 1)
-     seed_everything(seed)
-
-     max_area = 1024 * 1024
-     oW, oH = person_image.width, person_image.height
-     ratio = math.sqrt(max_area / (oW * oH))
-     ratio = min(1, ratio)
-     tW, tH = int(oW * ratio) // 16 * 16, int(oH * ratio) // 16 * 16
-     transform = T.Compose([
-         T.Resize((tH, tW)),
-         T.ToTensor(),
-     ])
-     person_image = transform(person_image)
-
-     ratio = min(tW / object_image.width, tH / object_image.height)
-     transform = T.Compose([
-         T.Resize((int(object_image.height * ratio), int(object_image.width * ratio))),
-         T.ToTensor(),
-     ])
-     object_image_padded = torch.ones_like(person_image)
-     object_image = transform(object_image)
-     new_h, new_w = object_image.shape[1], object_image.shape[2]
-     min_x = (tW - new_w) // 2
-     min_y = (tH - new_h) // 2
-     object_image_padded[:, min_y: min_y + new_h, min_x: min_x + new_w] = object_image
-
-     prompts = [args.object_map[object_class]] * 2
-     img_cond = torch.stack([person_image, object_image_padded]).to(dtype=weight_dtype, device=device)
-     mask = torch.zeros_like(img_cond).to(img_cond)
-
-     with torch.no_grad():
-         img = pipeline(
-             prompt=prompts,
-             height=tH,
-             width=tW,
-             img_cond=img_cond,
-             mask=mask,
-             guidance_scale=guidance_scale,
-             num_inference_steps=steps,
-             generator=torch.Generator(device).manual_seed(seed),
-         ).images[0]
-
-     return img
-
-
- # Custom CSS
- custom_css = """
- /* Overall background */
- .gradio-container {
-     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     font-family: 'Inter', sans-serif;
- }
-
- /* Category thumbnails container */
- .category-container {
-     display: flex;
-     flex-wrap: wrap;
-     gap: 15px;
-     justify-content: center;
-     margin: 20px 0;
-     padding: 20px;
-     background: rgba(255, 255, 255, 0.1);
-     border-radius: 15px;
-     backdrop-filter: blur(10px);
- }
-
- .category-item {
-     display: flex;
-     flex-direction: column;
-     align-items: center;
-     cursor: pointer;
-     padding: 10px;
-     border-radius: 10px;
-     background: rgba(255, 255, 255, 0.9);
-     transition: all 0.3s ease;
-     min-width: 90px;
- }
-
- .category-item:hover {
-     transform: translateY(-5px);
-     box-shadow: 0 5px 20px rgba(0, 0, 0, 0.2);
-     background: white;
- }
-
- .category-item.selected {
-     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-     color: white;
- }
-
- .category-item img {
-     width: 60px;
-     height: 60px;
-     object-fit: contain;
-     margin-bottom: 5px;
- }
-
- .category-item span {
-     font-size: 0.85em;
-     text-align: center;
-     font-weight: 500;
- }
-
- /* === Remove all placeholders === */
- .gr-image svg,
- .gr-image [data-testid*="placeholder"],
- .gr-image [class*="placeholder"],
- .gr-image [aria-label*="placeholder"],
- .gr-image [class*="svelte"][class*="placeholder"],
- .gr-image .absolute.inset-0.flex.items-center.justify-center,
- .gr-image .flex.items-center.justify-center svg {
-     display: none !important;
-     visibility: hidden !important;
- }
- .gr-image [class*="overlay"],
- .gr-image .fixed.inset-0,
- .gr-image .absolute.inset-0 {
-     pointer-events: none !important;
- }
-
- /* Image upload area */
- .gr-image .wrap { background: transparent !important; min-height: 400px !important; }
- .gr-image .upload-container {
-     min-height: 400px !important;
-     border: 3px dashed rgba(102, 126, 234, 0.4) !important;
-     border-radius: 12px !important;
-     background: linear-gradient(135deg, rgba(248, 250, 252, 0.5) 0%, rgba(241, 245, 249, 0.5) 100%) !important;
-     position: relative !important;
- }
- /* When an image is present */
- .gr-image:has(img) .upload-container { border: none !important; background: transparent !important; }
- /* Hint text */
- .gr-image .upload-container::after {
-     content: "Click or Drag to Upload";
-     position: absolute; top: 50%; left: 50%;
-     transform: translate(-50%, -50%);
-     color: rgba(102, 126, 234, 0.7);
-     font-size: 1.05em; font-weight: 500;
-     pointer-events: none;
- }
- .gr-image:has(img) .upload-container::after { display: none !important; }
- /* Uploaded image */
- .gr-image img { border-radius: 12px !important; position: relative !important; z-index: 10 !important; }
-
- /* Button style */
- .gr-button-primary {
-     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
-     color: white !important; border: none !important;
-     padding: 15px 40px !important; font-size: 1.2em !important;
-     border-radius: 50px !important; cursor: pointer !important;
- }
-
- /* Radio button styling for categories */
- .gr-radio {
-     display: none !important;
- }
- """
-
- if __name__ == '__main__':
-
-     with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
-         with gr.Column(elem_id="header"):
-             gr.HTML("""
-                 <h1 style="text-align: center; color: white; margin: 20px 0;">
-                     ✨ CodiFit-AI Virtual Try-On ✨
-                 </h1>
-                 <p style="text-align: center; color: rgba(255,255,255,0.9); margin-bottom: 30px;">
-                     Experience the future of fashion with AI-powered virtual clothing try-on
-                 </p>
-             """)
-
-         # Category selector with thumbnails
-         gr.Markdown("### Select Fashion Category")
-         with gr.Row():
-             category_selector = gr.Radio(
-                 choices=list(CATEGORY_SAMPLES.keys()),
-                 value='top',
-                 label="Fashion Categories",
-                 elem_id="category_radio",
-                 visible=False  # Hide default radio buttons
-             )
-
-         # Display category thumbnails
-         category_html = gr.HTML(create_category_html())
-
-         with gr.Row(equal_height=True):
-             with gr.Column(scale=1):
-                 person_image = gr.Image(
-                     type="pil",
-                     label="Upload Person Photo",
-                     height=500,
-                     interactive=True,
-                     value=load_random_person()  # Load random person on init
-                 )
-
-             with gr.Column(scale=1):
-                 object_image = gr.Image(
-                     type="pil",
-                     label="Upload Object Image",
-                     height=400,
-                     interactive=True,
-                     value=load_sample_image('top')  # Load default top image
-                 )
-                 object_class = gr.Dropdown(
-                     label='Selected Object Category',
-                     choices=list(args.object_map.keys()),
-                     value='top',
-                     interactive=True
-                 )
-                 run_button = gr.Button(value="🚀 Generate Try-On", variant='primary')
-
-             with gr.Column(scale=1):
-                 image_out = gr.Image(type="pil", label="Virtual Try-On Result", height=500, interactive=False)
-
-         with gr.Accordion("⚙️ Advanced Settings", open=False):
-             with gr.Row():
-                 guidance_scale = gr.Slider(label="🎯 Guidance Scale", minimum=1, maximum=50, value=30, step=0.1)
-                 steps = gr.Slider(label="🔄 Inference Steps", minimum=1, maximum=50, value=20, step=1)
-                 seed = gr.Number(label="🎲 Random Seed", value=-1, precision=0)
-
-         # JavaScript for category selection interaction
-         demo.load(None, None, None, js="""
-         function() {
-             // Add click handlers to category items
-             document.querySelectorAll('.category-item').forEach(item => {
-                 item.addEventListener('click', function() {
-                     // Remove selected class from all items
-                     document.querySelectorAll('.category-item').forEach(i =>
-                         i.classList.remove('selected')
-                     );
-                     // Add selected class to clicked item
-                     this.classList.add('selected');
-
-                     // Get category name
-                     const category = this.dataset.category;
-
-                     // Trigger category selection (would need proper Gradio event handling)
-                     console.log('Selected category:', category);
-                 });
-             });

-             // Select first category by default
-             document.querySelector('.category-item')?.classList.add('selected');
-         }
-         """)
-
-         # Handle category selection
-         def on_category_select(category):
-             """Update object image and class when category is selected"""
-             sample_img = load_sample_image(category)
-             return sample_img, category
-
-         # Connect category selector to update object image and class
-         category_selector.change(
-             fn=on_category_select,
-             inputs=[category_selector],
-             outputs=[object_image, object_class]
-         )
-
-         # Manual category selection through radio (hidden but functional)
-         def update_from_thumbnail(evt: gr.SelectData):
-             category = list(CATEGORY_SAMPLES.keys())[evt.index]
-             sample_img = load_sample_image(category)
-             return sample_img, category
-
-         # Run generation
-         run_button.click(
-             generate,
-             inputs=[person_image, object_image, object_class, steps, guidance_scale, seed],
-             outputs=[image_out]
-         )

-     demo.launch(server_name="0.0.0.0")
  import os
+ import sys
+ import streamlit as st
+ from tempfile import NamedTemporaryFile
+
+ def main():
+     try:
+         # Get the code from secrets
+         code = os.environ.get("MAIN_CODE")
+
+         if not code:
+             st.error("⚠️ The application code wasn't found in secrets. Please add the MAIN_CODE secret.")
+             return
+
+         # Create a temporary Python file
+         with NamedTemporaryFile(suffix='.py', delete=False, mode='w') as tmp:
+             tmp.write(code)
+             tmp_path = tmp.name
+
+         # Execute the code
+         exec(compile(code, tmp_path, 'exec'), globals())
+
+         # Clean up the temporary file
+         try:
+             os.unlink(tmp_path)
+         except:
+             pass

+     except Exception as e:
+         st.error(f"⚠️ Error loading or executing the application: {str(e)}")
+         import traceback
+         st.code(traceback.format_exc())

+ if __name__ == "__main__":
+     main()
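
The new app.py is a thin launcher: it reads the application source from the MAIN_CODE Space secret, writes it to a temporary file so tracebacks carry a real filename, and execs it in the current globals. A minimal standalone sketch of the same pattern, assuming only that MAIN_CODE is set in the environment (the fallback print default is hypothetical, for local testing only):

import os
from tempfile import NamedTemporaryFile

# Read the source from the environment; the default is a stand-in for testing.
code = os.environ.get("MAIN_CODE", 'print("hello from MAIN_CODE")')

# Write to a named temp file so compile() can attach a real filename.
with NamedTemporaryFile(suffix=".py", delete=False, mode="w") as tmp:
    tmp.write(code)
    tmp_path = tmp.name

try:
    # Execute in the current globals, as the launcher above does.
    exec(compile(code, tmp_path, "exec"), globals())
finally:
    os.unlink(tmp_path)  # remove the temp file even if exec raises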