Spaces:
Runtime error
Runtime error
Update segment.py
Browse files- segment.py +8 -0
segment.py
CHANGED
|
@@ -11,6 +11,7 @@ import numpy as np
|
|
| 11 |
import argparse
|
| 12 |
import matplotlib
|
| 13 |
import gradio as gr
|
|
|
|
| 14 |
|
| 15 |
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
|
| 16 |
if type(image_path) is str:
|
|
@@ -52,6 +53,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
| 52 |
if torch.min(segmentation) == 0:
|
| 53 |
mask = segmentation==0
|
| 54 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
|
|
|
|
|
|
| 55 |
segment_label = "rest"
|
| 56 |
color = viridis(0)
|
| 57 |
label = f"{segment_label}-{0}"
|
|
@@ -65,6 +68,8 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
| 65 |
if torch.min(segmentation) != 0:
|
| 66 |
segment_id -= 1
|
| 67 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
|
|
|
|
|
|
| 68 |
mask_np_list.append(mask)
|
| 69 |
segment_label = model.config.id2label[segment['label_id']]
|
| 70 |
instances_counter[segment['label_id']] += 1
|
|
@@ -76,6 +81,9 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
|
|
| 76 |
label_list.append(label)
|
| 77 |
else:
|
| 78 |
mask = np.full(segmentation.shape, True)
|
|
|
|
|
|
|
|
|
|
| 79 |
segment_label = "all"
|
| 80 |
mask_np_list.append(mask)
|
| 81 |
color = viridis(0)
|
|
|
|
| 11 |
import argparse
|
| 12 |
import matplotlib
|
| 13 |
import gradio as gr
|
| 14 |
+
import cv2
|
| 15 |
|
| 16 |
def load_image(image_path, left=0, right=0, top=0, bottom=0, size = 512):
|
| 17 |
if type(image_path) is str:
|
|
|
|
| 53 |
if torch.min(segmentation) == 0:
|
| 54 |
mask = segmentation==0
|
| 55 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
| 56 |
+
print(mask.shape)
|
| 57 |
+
mask = cv2.resize(mask,(512,512))
|
| 58 |
segment_label = "rest"
|
| 59 |
color = viridis(0)
|
| 60 |
label = f"{segment_label}-{0}"
|
|
|
|
| 68 |
if torch.min(segmentation) != 0:
|
| 69 |
segment_id -= 1
|
| 70 |
mask = mask.cpu().detach().numpy() # [512,512] bool
|
| 71 |
+
print(mask.shape)
|
| 72 |
+
mask = cv2.resize(mask,(512,512))
|
| 73 |
mask_np_list.append(mask)
|
| 74 |
segment_label = model.config.id2label[segment['label_id']]
|
| 75 |
instances_counter[segment['label_id']] += 1
|
|
|
|
| 81 |
label_list.append(label)
|
| 82 |
else:
|
| 83 |
mask = np.full(segmentation.shape, True)
|
| 84 |
+
print(mask.shape)
|
| 85 |
+
mask = cv2.resize(mask,(512,512))
|
| 86 |
+
|
| 87 |
segment_label = "all"
|
| 88 |
mask_np_list.append(mask)
|
| 89 |
color = viridis(0)
|