Spaces: Runtime error
Update utils.py
utils.py CHANGED
@@ -5,11 +5,12 @@ from src.flux.condition import Condition
 from PIL import Image
 import argparse
 import os
+import spaces
 import json
 import base64
 import io
 import re
-from PIL import
+from PIL import ImageFilter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from scipy.ndimage import binary_dilation
 import cv2
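Note: the new import spaces only resolves inside a Hugging Face Spaces (ZeroGPU) runtime. A minimal guard, an assumption on my part rather than part of this commit, would keep utils.py importable elsewhere:

try:
    import spaces  # provided by the Hugging Face Spaces runtime
except ImportError:
    class spaces:  # hypothetical fallback: make @spaces.GPU a no-op locally
        @staticmethod
        def GPU(fn):
            return fn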
@@ -27,6 +28,31 @@ except ImportError:
 
 import re
 
+pipe = None
+model_dict = {}
+
+def init_flux_pipeline():
+    global pipe
+    if pipe is None:
+        pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            torch_dtype=torch.bfloat16
+        )
+        pipe = pipe.to("cuda")
+
+def get_model(model_path):
+    global model_dict
+    if model_path not in model_dict:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            torch_dtype="auto",
+            device_map="auto",
+            trust_remote_code=True
+        ).eval()
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        model_dict[model_path] = (model, tokenizer)
+    return model_dict[model_path]
+
 def encode_image_to_datauri(path, size=(512, 512)):
     with Image.open(path).convert('RGB') as img:
         img = img.resize(size, Image.LANCZOS)
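The two helpers above add lazy, process-wide caching: weights load on first use and are reused across requests instead of being reloaded per call. A usage sketch (assuming GPU access and reachable weights):

init_flux_pipeline()                                 # first call builds the pipe
init_flux_pipeline()                                 # later calls are no-ops
model, tokenizer = get_model("ByteDance/Sa2VA-8B")   # loads once, then cached
same_model, _ = get_model("ByteDance/Sa2VA-8B")      # cache hit
assert model is same_model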
@@ -34,8 +60,6 @@ def encode_image_to_datauri(path, size=(512, 512)):
         img.save(buffer, format='PNG')
         b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
         return b64
-        # return f"data:image/png;base64,{b64}"
-
 
 @retry(
     reraise=True,
@@ -93,7 +117,6 @@ def cot_with_gpt(image_uri, instruction):
     categories, instructions = extract_instructions(text)
     return categories, instructions
 
-
 def extract_instructions(text):
     categories = []
     instructions = []
@@ -134,9 +157,9 @@ def extract_last_bbox(result):
     x0, y0, x1, y1 = map(int, last_match[1:])
     return x0, y0, x1, y1
 
-
+@spaces.GPU
 def infer_with_DiT(task, image, instruction, category):
-
+    init_flux_pipeline()
 
     if task == 'RoI Inpainting':
         if category == 'Add' or category == 'Replace':
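On ZeroGPU Spaces, @spaces.GPU asks the runtime to attach a GPU only while the decorated function executes, so calling init_flux_pipeline() inside it keeps all CUDA work within that window. A minimal sketch of the decorator's contract (the probe function is illustrative, not from this repo):

import spaces
import torch

@spaces.GPU          # a GPU is allocated for the duration of this call
def gpu_probe():
    return str(torch.zeros(1, device="cuda").device)

The decorator also accepts a duration argument, e.g. @spaces.GPU(duration=120), when a call needs more than the default allocation window.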
@@ -180,18 +203,14 @@ def infer_with_DiT(task, image, instruction, category):
         condition = Condition("scene", image, position_delta=(0, -32))
     else:
         raise ValueError(f"Invalid task: '{task}'")
-
-
-        torch_dtype=torch.bfloat16
-    )
-
-    pipe = pipe.to("cuda")
-
+
+    pipe.unload_lora_weights()
     pipe.load_lora_weights(
         "Cicici1109/IEAP",
         weight_name=lora_path,
         adapter_name="scene",
     )
+
     result_img = generate(
         pipe,
         prompt=instruction_dit,
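Unloading before loading matters now that pipe is a long-lived global: without pipe.unload_lora_weights(), a later request would try to register the adapter name "scene" on a pipeline that already carries it from the previous call. The sequence, as a hypothetical helper mirroring the diff:

def swap_scene_lora(pipe, lora_path):
    # drop the adapter left over from the previous request
    pipe.unload_lora_weights()
    pipe.load_lora_weights(
        "Cicici1109/IEAP",
        weight_name=lora_path,
        adapter_name="scene",   # safe to reuse the name after unloading
    )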
@@ -201,15 +220,13 @@ def infer_with_DiT(task, image, instruction, category):
         height=512,
         width=512,
     ).images[0]
-
+
     if task == 'RoI Editing' and category == 'Action Change':
         text_roi = extract_object_with_gpt(instruction)
         instruction_loc = f"<image>Please segment {text_roi}."
-        # (model, tokenizer, image_path, instruction, work_dir, dilate):
         img = result_img
-        # print(f"Instruction: {instruction_loc}")
 
-        model, tokenizer =
+        model, tokenizer = get_model("ByteDance/Sa2VA-8B")
 
         result = model.predict_forward(
             image=img,
@@ -218,13 +235,11 @@ def infer_with_DiT(task, image, instruction, category):
         )
 
         prediction = result['prediction']
-        # print(f"Model Output: {prediction}")
 
         if '[SEG]' in prediction and 'prediction_masks' in result:
            pred_mask = result['prediction_masks'][0]
            pred_mask_np = np.squeeze(np.array(pred_mask))
 
-            ## obtain region bbox
            rows = np.any(pred_mask_np, axis=1)
            cols = np.any(pred_mask_np, axis=0)
            if not np.any(rows) or not np.any(cols):
@@ -238,18 +253,10 @@ def infer_with_DiT(task, image, instruction, category):
 
         return changed_instance, x0, y1, 1
 
-
     return result_img
 
 def load_model(model_path):
-
-        model_path,
-        torch_dtype="auto",
-        device_map="auto",
-        trust_remote_code=True
-    ).eval()
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    return model, tokenizer
+    return get_model(model_path)
 
 def extract_object_with_gpt(instruction):
     system_prompt = (
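Keeping load_model as a one-line wrapper preserves the existing call sites while routing them through the new cache, so both entry points return the same objects. For example (assuming the weights are reachable):

model_a, tok_a = load_model("ByteDance/Sa2VA-8B")  # old entry point
model_b, tok_b = get_model("ByteDance/Sa2VA-8B")   # new entry point, cache hit
assert model_a is model_b and tok_a is tok_b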
@@ -304,7 +311,6 @@ def extract_region_with_gpt(instruction):
             max_tokens=20,
         )
         object_phrase = response.choices[0].message['content'].strip().strip('"')
-        # print(f"Identified object: {object_phrase}")
         return object_phrase
     except Exception as e:
         print(f"GPT extraction failed: {e}")
@@ -372,8 +378,9 @@ def crop_masked_region(image, pred_mask_np):
 
     return Image.fromarray(cropped_image, mode='RGBA')
 
-
-
+@spaces.GPU
+def roi_localization(image, instruction, category):
+    model, tokenizer = get_model("ByteDance/Sa2VA-8B")
     if category == 'Add':
         text_roi = extract_region_with_gpt(instruction)
     else:
@@ -389,13 +396,11 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
     )
 
     prediction = result['prediction']
-    # print(f"Model Output: {prediction}")
 
     if '[SEG]' in prediction and 'prediction_masks' in result:
         pred_mask = result['prediction_masks'][0]
         pred_mask_np = np.squeeze(np.array(pred_mask))
         if category == 'Add':
-            ## obtain region bbox
             rows = np.any(pred_mask_np, axis=1)
             cols = np.any(pred_mask_np, axis=0)
             if not np.any(rows) or not np.any(cols):
@@ -405,17 +410,14 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
             y0, y1 = np.where(rows)[0][[0, -1]]
             x0, x1 = np.where(cols)[0][[0, -1]]
 
-
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
             x0, y0, x1, y1 = layout_add(bbox, instruction)
             mask = bbox_to_mask(x0, y0, x1, y1)
-            ## make it black
             masked_img = get_masked(mask, img)
         elif category == 'Move' or category == 'Resize':
             dilated_original_mask = binary_dilation(pred_mask_np, iterations=3)
             masked_img = get_masked(dilated_original_mask, img)
-
+
             rows = np.any(pred_mask_np, axis=1)
             cols = np.any(pred_mask_np, axis=0)
             if not np.any(rows) or not np.any(cols):
@@ -425,12 +427,10 @@ def roi_localization(image, instruction, category): # add, remove, replace, acti
             y0, y1 = np.where(rows)[0][[0, -1]]
             x0, x1 = np.where(cols)[0][[0, -1]]
 
-
-            bbox = combine_bbox(text_roi, x0, y0, x1, y1) #? multiple?
-            # print(bbox)
+            bbox = combine_bbox(text_roi, x0, y0, x1, y1)
             x0_new, y0_new, x1_new, y1_new, = layout_change(bbox, instruction)
             scale = (y1_new - y0_new) / (y1 - y0)
-
+
             changed_instance = crop_masked_region(img, pred_mask_np)
 
         return masked_img, changed_instance, x0_new, y1_new, scale
@@ -588,4 +588,7 @@ def layout_change(bbox, instruction):
     result = response.choices[0].message.content.strip()
 
     bbox = extract_last_bbox(result)
-    return bbox
+    return bbox
+
+if __name__ == "__main__":
+    init_flux_pipeline()
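The new __main__ guard lets the Space warm the FLUX pipeline by running the module directly, while plain imports stay cheap because initialization remains lazy:

# python utils.py   -> downloads and initializes FLUX.1-dev once, up front
# import utils      -> loads nothing until init_flux_pipeline()/get_model() runs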