| | import numpy as np |
| | from PIL import Image |
| | import axengine as ort |
| | import torch |
| | from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop |
| | from tokenizer import SimpleTokenizer |
| | import argparse |
| |
|
# Per-channel RGB normalization statistics. Values match the constants
# published for OpenAI's CLIP training data (name suggests this origin;
# used by image_transform_v2's Normalize step).
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
| |
|
| |
|
def image_transform_v2(resolution=256):
    """Build the normalized image preprocessing pipeline.

    Resizes with BICUBIC interpolation, center-crops to a square,
    converts to a float tensor and normalizes with the OpenAI CLIP
    mean/std constants.

    Args:
        resolution: Target side length in pixels used for both the
            resize and the center crop. Defaults to 256, preserving
            the original hard-coded behavior.

    Returns:
        A torchvision ``Compose`` that maps a PIL image to a
        normalized float tensor of shape (3, resolution, resolution).
    """
    aug_list = [
        # Resize the shorter side to `resolution` (aspect ratio kept),
        # then crop the central square.
        Resize(
            resolution,
            interpolation=InterpolationMode.BICUBIC,
        ),
        CenterCrop(resolution),
        ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    ]
    preprocess = Compose(aug_list)
    return preprocess
| | |
| | |
def image_transform_v1(resolution=256):
    """Build the unnormalized image preprocessing pipeline.

    Resizes with BILINEAR interpolation, center-crops to a square and
    converts to a float tensor in [0, 1]. Unlike
    ``image_transform_v2`` there is no mean/std normalization —
    presumably the model consumes raw [0, 1] pixels (TODO confirm
    against the exported model).

    Args:
        resolution: Target side length in pixels used for both the
            resize and the center crop. Defaults to 256, preserving
            the original hard-coded behavior.

    Returns:
        A torchvision ``Compose`` that maps a PIL image to a float
        tensor of shape (3, resolution, resolution).
    """
    aug_list = [
        Resize(
            resolution,
            interpolation=InterpolationMode.BILINEAR,
        ),
        CenterCrop(resolution),
        ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
    ]
    preprocess = Compose(aug_list)
    return preprocess
| |
|
| |
|
def softmax(x, axis=-1):
    """Apply softmax to a numpy array along the given axis.

    Args:
        x: Input numpy array.
        axis: Axis along which softmax is computed (default: the
            last axis, -1).

    Returns:
        A numpy array with the same shape as ``x`` whose values along
        ``axis`` are non-negative and sum to 1.
    """
    # Shift by the per-slice maximum before exponentiating so very
    # large logits cannot overflow exp().
    shifted = x - x.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
| |
|
if __name__ == "__main__":
    # CLI: encoder model paths, an input image and candidate captions.
    parser = argparse.ArgumentParser()
    parser.add_argument("-ie", "--image_encoder_path", type=str,
                        default="./mobileclip2_s4_image_encoder.axmodel",
                        help="image encoder axmodel path")
    parser.add_argument("-te", "--text_encoder_path", type=str,
                        default="./mobileclip2_s4_text_encoder.axmodel",
                        help="text encoder axmodel path")
    parser.add_argument("-i", "--image", type=str, default="./zebra.jpg",
                        help="input image path")
    parser.add_argument("-t", "--class_text", type=str, nargs='+',
                        default=["a zebra", "a dog", "two zebras"],
                        help='List of captions, e.g.: "a zebra" "a dog" "two zebras"')
    args = parser.parse_args()

    # Preprocessing: v1 pipeline (bilinear resize + center crop, no
    # mean/std normalization) and a 77-token CLIP-style tokenizer.
    transform = image_transform_v1()
    tokenizer = SimpleTokenizer(context_length=77)

    pixel_batch = transform(Image.open(args.image).convert('RGB')).unsqueeze(0)
    token_ids = tokenizer(args.class_text).to(torch.int32)

    # axengine exposes an ONNX-Runtime-like InferenceSession API.
    image_encoder = ort.InferenceSession(args.image_encoder_path)
    text_encoder = ort.InferenceSession(args.text_encoder_path)

    image_features = image_encoder.run(["unnorm_image_features"],
                                       {"image": np.array(pixel_batch)})[0]
    text_features = text_encoder.run(["unnorm_text_features"],
                                     {"text": token_ids.numpy()})[0]

    # L2-normalize both embedding sets so their dot product is a
    # cosine similarity.
    image_features /= np.linalg.norm(image_features, ord=2, axis=-1, keepdims=True)
    text_features /= np.linalg.norm(text_features, ord=2, axis=-1, keepdims=True)

    # CLIP-style logits: scale similarities by 100, softmax over captions.
    text_probs = softmax(100.0 * image_features @ text_features.T)

    print("Label probs:", text_probs)