Update demo/seagull_inference.py
demo/seagull_inference.py  +25 -12
@@ -13,6 +13,7 @@ import numpy as np
 import cv2
 from typing import List
 from PIL import Image
+from pycocotools import mask as mask_utils
 
 class Seagull():
     def __init__(self, model_path, device='cuda'):
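The only new dependency this hunk introduces is pycocotools (installable with `pip install pycocotools`), whose `mask` module is used further down to decode COCO run-length-encoded (RLE) region masks.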
@@ -40,9 +41,9 @@ class Seagull():
         begin_str = "<image>\nThis provides an overview of the image.\n Please answer the following questions about the provided region. Note: Distortions include: blur, colorfulness, compression, contrast exposure and noise.\n Here is the region <global><local>. "
 
         instruction = {
-            'distortion…
-            'quality…
-            'importance…
+            'distortion': 'Provide the distortion type of this region.',
+            'quality': 'Analyze the quality of this region.',
+            'importance': 'Consider the impact of this region on the overall image quality. Analyze its importance to the overall image quality.'
         }
 
         self.ids_input = {}
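Each entry pairs an instruction-type key with the question that gets appended to `begin_str`; `self.ids_input`, initialized just below, is where the per-type token ids end up (the tokenization itself is outside this hunk). For illustration, the full prompt assembled for the 'quality' key would read:

    # Illustrative only; the code that fills self.ids_input is not in this diff.
    prompt = begin_str + instruction['quality']
    # "<image>\nThis provides an overview of the image.\n Please answer the
    # following questions about the provided region. Note: Distortions include:
    # blur, colorfulness, compression, contrast exposure and noise.\n Here is
    # the region <global><local>. Analyze the quality of this region."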
@@ -70,7 +71,7 @@ class Seagull():
         else:
             preprocessed_img = img.copy()
 
-        return (preprocessed_img, preprocessed_img, preprocessed_img)
+        return (preprocessed_img, preprocessed_img, preprocessed_img, preprocessed_img)
 
     def preprocess(self, img):
         image = self.image_processor.preprocess(img,
@@ -83,19 +84,31 @@ class Seagull():
             align_corners=False).squeeze(0)
 
         return image
-
-    def seagull_predict(self, img, mask, instruct_type):
-        image = self.preprocess(img)
 
+    def seagull_predict(self, img, mask, instruct_type, mask_type='rle'):
+        if isinstance(img, str):
+            img = cv2.imread(img)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        h, w, _ = img.shape
+
+        if mask_type == 'rle': # use the mask to indicate the roi
+            compressed_rle = {'size' : [h, w], 'counts' : mask}
+            mask = mask_utils.decode(compressed_rle)
+        elif mask_type == 'points': # use the point to indicate the roi
+            x_min, y_min, w1, h1 = mask
+            x_max, y_max = x_min + w1, y_min + h1
+            mask = np.zeros_like(img[:, :, 0])
+            mask[max(0, y_min):min(y_max, mask.shape[0]), max(0, x_min):min(x_max, mask.shape[1])] = 1
+
+        image = self.preprocess(img)
         mask = np.array(mask, dtype=np.int)
+
         ys, xs = np.where(mask > 0)
         if len(xs) > 0 and len(ys) > 0:
-            # Find the minimal bounding rectangle for the entire mask
            x_min, x_max = np.min(xs), np.max(xs)
             y_min, y_max = np.min(ys), np.max(ys)
             w1 = x_max - x_min
             h1 = y_max - y_min
-
             bounding_box = (x_min, y_min, w1, h1)
         else:
             bounding_box = None
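The rewritten `seagull_predict` now accepts the region of interest in two forms: a COCO compressed-RLE `counts` value that pycocotools expands back into a binary mask, or an `(x, y, w, h)` box that is painted into a fresh binary mask; either way, the code then recovers a tight bounding box from the nonzero pixels. One caveat for running the demo: `np.int` was deprecated in NumPy 1.20 and removed in 1.24, so the `dtype=np.int` call above fails on a current NumPy and needs plain `int` instead. A self-contained sketch of preparing both input forms (toy sizes; the names here are illustrative):

    import numpy as np
    from pycocotools import mask as mask_utils

    h, w = 480, 640

    # Form 1: compressed RLE with mask_type='rle'. encode() expects a
    # Fortran-ordered uint8 array; its 'counts' field is the value that
    # seagull_predict wraps back into {'size': [h, w], 'counts': mask}.
    binary = np.zeros((h, w), dtype=np.uint8)
    binary[100:200, 150:300] = 1
    rle_counts = mask_utils.encode(np.asfortranarray(binary))['counts']
    assert (mask_utils.decode({'size': [h, w], 'counts': rle_counts}) == binary).all()

    # Form 2: a box with mask_type='points', given as (x, y, w, h).
    box = (150, 100, 150, 100)  # covers the same region as the mask above

    # From either form, seagull_predict derives the same bounding box:
    ys, xs = np.where(binary > 0)
    assert (xs.min(), ys.min()) == (150, 100)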
@@ -104,7 +117,7 @@ class Seagull():
         mask = np.array(mask > 0.1, dtype=np.uint8)
         masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
 
-        input_ids = self.ids_input[instruct_type.lower()]
+        input_ids = self.ids_input[instruct_type.split()[0].lower()]
 
         x1, y1, w1, h1 = list(map(int, bounding_box)) # x y w h
         cropped_img = img[y1:y1 + h1, x1:x1 + w1]
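Keying the lookup on `instruct_type.split()[0].lower()` rather than `instruct_type.lower()` means only the first word has to match a key in `self.ids_input`, so a longer label (the exact labels callers pass are not shown in this diff) still resolves:

    'Distortion analysis'.split()[0].lower()   # -> 'distortion'
    'quality'.split()[0].lower()               # -> 'quality' (unchanged behaviour)

Note that this raises IndexError on an empty string, just as the old version raised KeyError on an unknown label.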
@@ -127,8 +140,8 @@ class Seagull():
             max_new_tokens=2048,
             use_cache=True,
             num_beams=1,
-            top_k = 0,
-            top_p = 1,
+            top_k = 0,
+            top_p = 1,
         )
 
         self.model.forward = self.model.orig_forward
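For context on the re-added sampling arguments: in a HuggingFace-style `generate` call, `top_k=0` disables top-k filtering and `top_p=1` disables nucleus filtering, so together with `num_beams=1` the call decodes greedily unless `do_sample=True` is set elsewhere in the (unshown) argument list.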
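Taken together, a minimal driver for the updated entry point might look like the following sketch. The checkpoint and image paths are placeholders, `rle_counts` is an RLE value as produced in the sketch above, and the assumption that `seagull_predict` returns the generated answer is not confirmed by this diff:

    from demo.seagull_inference import Seagull

    seagull = Seagull(model_path='path/to/seagull-checkpoint')  # placeholder path

    # Region as a box (x, y, w, h); a str image is now loaded via cv2.imread.
    print(seagull.seagull_predict('path/to/image.png', (150, 100, 150, 100),
                                  instruct_type='quality', mask_type='points'))

    # Region as a COCO compressed-RLE counts value (see the sketch above).
    print(seagull.seagull_predict('path/to/image.png', rle_counts,
                                  instruct_type='distortion', mask_type='rle'))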