import gradio as gr
import torch
import vision_transformer as models
import cv2
from torch import nn
from utils import load_pretrained_weights

class PatchEmbedding:
    """
    Loads a pretrained ViT-Base model and generates patch tokens for an input image.

    Args:
        pretrained_weights (str): Path to the pretrained weight file.
        arch (str, optional): Architecture of the model. Defaults to "vit_base".
        patch_size (int, optional): Size of the patches extracted from the image. Defaults to 16.

    Attributes:
        model: The image embedding model.
        embed_dim (int): Dimensionality of the image embeddings.

    Methods:
        load_pretrained_weights(pretrained_weights): Load pretrained weights into the model.
        get_representation(image): Generate the CLS token and patch tokens for an input image.
    """
    def __init__(self, pretrained_weights, arch='vit_base', patch_size=16):
        self.model = models.__dict__[arch](patch_size=patch_size, num_classes=0)
        self.embed_dim = self.model.embed_dim
        self.model.eval().requires_grad_(False)
        self.load_pretrained_weights(pretrained_weights)
        from torchvision import transforms
        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        self.tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])
    def load_pretrained_weights(self, pretrained_weights):
        load_pretrained_weights(self.model, pretrained_weights)
    def get_representation(self, image):
        """
        Generate the CLS token and patch tokens for the input image.

        Args:
            image (ndarray): Input image as an (H, W, C) array.

        Returns:
            cls_token (ndarray): L2-normalized CLS token of shape (C,).
            patch_tokens (ndarray): L2-normalized patch tokens of shape (N, C).
        """
        img = self.tfms(image)
        x = img[None, :]
        tokens = self.model.forward_features(x)[0]  # N + 1, C (CLS token followed by N patch tokens)
        tokens = nn.functional.normalize(tokens, dim=-1, p=2).numpy()
        cls_token = tokens[0]      # C
        patch_tokens = tokens[1:]  # N, C
        return cls_token, patch_tokens
    def __call__(self, x):
        return self.model.forward_features(x)

default_shape = (224, 224)
embedding = PatchEmbedding('weights/mmc.pth')
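# With patch_size=16 and 224x224 inputs, the backbone produces a 14x14 grid of
# patch tokens (224 / 16 = 14); both demo functions below rely on this grid size.
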
def classify(query_image, support_image):
    # Score the match as the cosine similarity between the two CLS tokens
    # (already L2-normalized), scaled by 100.
    q_cls = embedding.get_representation(query_image)[0]
    s_cls = embedding.get_representation(support_image)[0]
    sim = (q_cls * s_cls).sum() * 100
    return f"{sim:.2f}"

def segment(threshold, input):
    image = input['image']
    mask = input['mask']
    patch_tokens = embedding.get_representation(image)[1]
    # Downsample the user-drawn sketch mask to the 14x14 patch grid and select
    # the patch tokens covered by the sketch.
    select = (cv2.resize(mask[:, :, 0], (14, 14)) > 0).flatten()
    q_pat = patch_tokens[select].mean(0)  # C
    # Similarity between every patch token and the mean token of the sketched region.
    sim = patch_tokens @ q_pat[:, None]  # N, 1
    # Threshold the similarity map and upsample it back to the image resolution.
    mask = (sim.reshape(14, 14) > threshold).astype('float')
    mask = cv2.resize(mask, (224, 224))
    ans = image * mask[:, :, None]
    return ans.astype('uint8')

classification_tab = gr.Interface(
    fn=classify,
    inputs=[
        gr.inputs.Image(label="Query Image", shape=default_shape),
        gr.inputs.Image(label="Support Image", shape=default_shape)
    ],
    outputs=gr.outputs.Textbox(label="Prediction"),
    title="Classification"
)

segmentation_tab = gr.Interface(
    fn=segment,
    inputs=[
        gr.inputs.Slider(0, 1, step=0.01, label="Threshold", default=0.8),
        gr.inputs.Image(label="Input Image", tool="sketch", shape=default_shape)
    ],
    outputs=gr.outputs.Image('numpy', label='Segmentation'),
    title="Segmentation"
)

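# The Blocks app below shows the paper's BibTeX entry and groups the two demos
# into a tabbed interface.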
with gr.Blocks() as app:
    gr.Markdown("""
    @misc{wu2023masked,
        title={Masked Momentum Contrastive Learning for Zero-shot Semantic Understanding},
        author={Jiantao Wu and Shentong Mo and Muhammad Awais and Sara Atito and Zhenhua Feng and Josef Kittler},
        year={2023},
        eprint={2308.11448},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""")
    interface = gr.TabbedInterface(
        [classification_tab, segmentation_tab],
        ["Classification", "Segmentation"]
    )

app.launch()
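
# To try the demo locally (assuming the MMC checkpoint is available at weights/mmc.pth and
# the repository's vision_transformer and utils modules are on the Python path):
#   python app.py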