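"""Gradio demo for HMMC bilingual (Chinese/English) text-to-video retrieval.

Loads precomputed video- and frame-level features for a fixed test set,
encodes the query with the language-matched text encoder, ranks videos by a
combined video/frame similarity, and shows the top results in a Gradio UI.
Model code: https://github.com/cheetah003/HMMC
"""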
import argparse
import time

import numpy as np
import torch
import gradio as gr
from transformers import BertTokenizer

# Repo-local modules from https://github.com/cheetah003/HMMC
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.modeling import BirdModel

show_num = 9
max_words = 32

# Precomputed video-level and frame-level features for the test set.
video_path_zh = "features/Chinese_batch_visual_output_list.npy"
frame_path_zh = "features/Chinese_batch_frame_output_list.npy"
video_fea_zh = torch.from_numpy(np.load(video_path_zh))
frame_fea_zh = torch.from_numpy(np.load(frame_path_zh))

video_path_en = "features/English_batch_visual_output_list.npy"
frame_path_en = "features/English_batch_frame_output_list.npy"
video_fea_en = torch.from_numpy(np.load(video_path_en))
frame_fea_en = torch.from_numpy(np.load(frame_path_en))

test_path = "test_list.txt"
# video_dir = "test1500_400_400/"
video_dir = "test1500/"
with open(test_path, 'r', encoding='utf8') as f_list:
    lines = f_list.readlines()
video_ids = [itm.strip() + ".mp4" for itm in lines]
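# Row i of video_fea_* / frame_fea_* corresponds to video_ids[i], so an index
# from a similarity top-k maps directly to a filename under video_dir.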


def get_videoname(idx):
    # Map top-k similarity indices back to video filenames and playable paths.
    videoname = []
    videopath = []
    for i in idx:
        name = video_ids[int(i)]
        videoname.append(name)
        videopath.append(video_dir + name)
    return videoname, videopath


def get_text(caption, tokenizer):
    # Tokenize the caption.
    words = tokenizer.tokenize(caption)
    # Add the start-of-text token.
    words = ["<|startoftext|>"] + words
    total_length_with_CLS = max_words - 1
    if len(words) > total_length_with_CLS:
        words = words[:total_length_with_CLS]
    # Add the end-of-text token.
    words = words + ["<|endoftext|>"]
    # Convert tokens to ids according to the vocab.
    input_ids = tokenizer.convert_tokens_to_ids(words)
    # Pad with zeros so every query has the same fixed length.
    input_mask = [1] * len(input_ids)
    while len(input_ids) < max_words:
        input_ids.append(0)
        input_mask.append(0)
    # Ensure the sequence length equals max_words.
    assert len(input_ids) == max_words
    assert len(input_mask) == max_words
    pairs_text = torch.from_numpy(np.array(input_ids).reshape(-1, max_words))
    pairs_mask = torch.from_numpy(np.array(input_mask).reshape(-1, max_words))
    return pairs_text, pairs_mask
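
# For example, get_text("a dog running", tokenizer_en) returns a (1, 32) id
# tensor and a (1, 32) mask tensor (1s over real tokens, 0s over padding).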


def get_args(description='Retrieval Task'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run pretraining.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--do_params", action='store_true', help="Test the params of the model.")
    parser.add_argument("--use_frame_fea", action='store_true', help="Whether to match text against frame features.")
    parser.add_argument('--task', type=str, default="retrieval", choices=["retrieval_VT", "retrieval"],
                        help="choose downstream task.")
    parser.add_argument('--dataset', type=str, default="bird", choices=["bird", "msrvtt", "vatex", "msvd"],
                        help="choose dataset.")
    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--text_lr', type=float, default=0.00001, help='text encoder learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='learning rate exp epoch decay')
    parser.add_argument('--weight_decay', type=float, default=0.2, help='weight decay')
    parser.add_argument('--n_display', type=int, default=100, help='information display frequency')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=32, help='')
    parser.add_argument('--max_frames', type=int, default=12, help='')
    parser.add_argument('--top_frames', type=int, default=3, help='')
    parser.add_argument('--frame_sample', type=str, default="uniform", choices=["uniform", "random", "uniform_random"],
                        help='frame sample strategy')
    parser.add_argument('--frame_sample_len', type=str, default="fix", choices=["dynamic", "fix"],
                        help='use dynamic or fixed frame length')
    parser.add_argument('--language', type=str, default="chinese", choices=["chinese", "english"],
                        help='language for text encoder')
    parser.add_argument('--use_temp', action='store_true', help='whether to use temporal transformer')
    parser.add_argument("--logdir", default=None, type=str, required=False, help="log dir for tensorboardX writer")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--pretrained_text", default="hfl/chinese-roberta-wwm-ext", type=str, required=False,
                        help="pretrained text encoder")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument('--enable_amp', action='store_true', help="whether to use pytorch amp")
    parser.add_argument("--world_size", default=0, type=int, help="distributed training")
    parser.add_argument("--local_rank", default=0, type=int, help="distributed training")
    parser.add_argument("--rank", default=0, type=int, help="distributed training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    args = parser.parse_args()
    # Force eval-time settings for the demo.
    args.do_eval = True
    args.use_frame_fea = True
    args.use_temp = True
    return args


def init_model(language):
    time1 = time.time()
    args = get_args()
    args.language = language
    if language == "chinese":
        model_path = "models/Chinese_vatex.bin"
        tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
    elif language == "english":
        model_path = "models/English_vatex.bin"
        tokenizer = ClipTokenizer()
    else:
        raise ValueError("language should be 'chinese' or 'english'!")
    model_state_dict = torch.load(model_path, map_location='cpu')
    cross_model = "cross-base"
    model = BirdModel.from_pretrained(cross_model, state_dict=model_state_dict, task_config=args)
    device = torch.device("cpu")
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print("language={}".format(language))
    print("init model time: {}".format(time.time() - time1))
    print("device: {}".format(device))
    return model, tokenizer


model_zh, tokenizer_zh = init_model(language="chinese")
model_en, tokenizer_en = init_model(language="english")
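# Both models are loaded once at startup, so each query only pays for text
# encoding plus similarity computation against the cached video features.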


def t2v_search_zh(text):
    with torch.no_grad():
        time1 = time.time()
        text_ids, text_mask = get_text(text, tokenizer_zh)
        print("get_text time: {}".format(time.time() - time1))
        time1 = time.time()
        text_fea_zh = model_zh.text_encoder(text_ids, text_mask)
        print("text_encoder time: {}".format(time.time() - time1))
        time1 = time.time()
        # Video-level similarity plus the mean of the top-k frame similarities.
        sim_video = model_zh.loose_similarity(text_fea_zh, video_fea_zh)
        sim_frame = model_zh.loose_similarity(text_fea_zh, frame_fea_zh)
        sim_frame = torch.topk(sim_frame, k=model_zh.top_frames, dim=1)[0]
        sim_frame = torch.mean(sim_frame, dim=1)
        sim = sim_video + sim_frame
        value, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
        print("calculate_similarity time: {}".format(time.time() - time1))
        print("value: {}".format(value))
        print("index: {}".format(index))
        videoname, videopath = get_videoname(index)
        print("videoname: {}".format(videoname))
        print("videopath: {}".format(videopath))
        return videopath
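

# t2v_search_en mirrors t2v_search_zh, using the English model and features.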
def t2v_search_en(text):
    with torch.no_grad():
        time1 = time.time()
        text_ids, text_mask = get_text(text, tokenizer_en)
        print("get_text time: {}".format(time.time() - time1))
        time1 = time.time()
        text_fea_en = model_en.text_encoder(text_ids, text_mask)
        print("text_encoder time: {}".format(time.time() - time1))
        time1 = time.time()
        sim_video = model_en.loose_similarity(text_fea_en, video_fea_en)
        sim_frame = model_en.loose_similarity(text_fea_en, frame_fea_en)
        sim_frame = torch.topk(sim_frame, k=model_en.top_frames, dim=1)[0]
        sim_frame = torch.mean(sim_frame, dim=1)
        sim = sim_video + sim_frame
        value, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
        print("calculate_similarity time: {}".format(time.time() - time1))
        print("value: {}".format(value))
        print("index: {}".format(index))
        videoname, videopath = get_videoname(index)
        print("videoname: {}".format(videoname))
        print("videopath: {}".format(videopath))
        return videopath
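
# A possible refactor (sketch only, not wired into the demo): the two search
# functions differ only in the model, tokenizer, and feature tensors they use,
# so they could share one helper along these lines:
#
# def t2v_search(text, model, tokenizer, video_fea, frame_fea):
#     with torch.no_grad():
#         text_ids, text_mask = get_text(text, tokenizer)
#         text_fea = model.text_encoder(text_ids, text_mask)
#         sim = model.loose_similarity(text_fea, video_fea)
#         sim_frame = model.loose_similarity(text_fea, frame_fea)
#         sim_frame = torch.topk(sim_frame, k=model.top_frames, dim=1)[0]
#         sim = sim + sim_frame.mean(dim=1)
#         _, index = sim.topk(show_num, dim=0, largest=True, sorted=True)
#         return get_videoname(index)[1]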


# Placeholder function; not used by the demo UI.
def hello_world(name):
    return "hello world, my name is " + name + "!"


def search_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# <div align='center'>HMMC中英文本-视频检索 \
            <a style='font-size:18px;color: #000000' href='https://github.com/cheetah003/HMMC'>Github</a></div>")
        demo.title = "HMMC中英文本-视频检索"
        # Chinese tab; labels: 输入文本 = "input text", 搜索 = "Search", 视频 = "video".
        with gr.Tab("中文"):
            with gr.Column(variant="panel"):
                with gr.Row(variant="compact"):
                    input_text = gr.Textbox(
                        label="输入文本",
                        show_label=False,
                        max_lines=1,
                        placeholder="请输入检索文本...",  # "Please input text to search..."
                    ).style(
                        container=False,
                    )
                    btn = gr.Button("搜索").style(full_width=False)
                with gr.Column(variant="panel", scale=2):
                    with gr.Row(variant="compact"):
                        videos_top = [gr.Video(
                            format="mp4", label="视频 " + str(i + 1),
                        ).style(height=300, width=300) for i in range(3)]
                with gr.Column(variant="panel", scale=1):
                    with gr.Row(variant="compact"):
                        videos_rest = [gr.Video(
                            format="mp4", label="视频 " + str(i + 1),
                        ).style(height=150, width=150) for i in range(3, show_num)]
                searched_videos = videos_top + videos_rest
                btn.click(t2v_search_zh, inputs=input_text, outputs=searched_videos)
        with gr.Tab("English"):
            with gr.Column(variant="panel"):
                with gr.Row(variant="compact"):
                    input_text = gr.Textbox(
                        label="input text",
                        show_label=False,
                        max_lines=1,
                        placeholder="Please input text to search...",
                    ).style(
                        container=False,
                    )
                    btn = gr.Button("Search").style(full_width=False)
                with gr.Column(variant="panel", scale=2):
                    with gr.Row(variant="compact"):
                        videos_top = [gr.Video(
                            format="mp4", label="video " + str(i + 1),
                        ).style(height=300, width=300) for i in range(3)]
                with gr.Column(variant="panel", scale=1):
                    with gr.Row(variant="compact"):
                        videos_rest = [gr.Video(
                            format="mp4", label="video " + str(i + 1),
                        ).style(height=150, width=150) for i in range(3, show_num)]
                searched_videos = videos_top + videos_rest
                btn.click(t2v_search_en, inputs=input_text, outputs=searched_videos)
    demo.launch()
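
# Note: .style(...) on components is the Gradio 3.x API; Gradio 4+ moved these
# options into the component constructors, so this script assumes Gradio 3.x.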


if __name__ == '__main__':
    search_demo()
    # Quick local test (bypasses the UI):
    # text = "两个男人正在随着音乐跳舞,他们正在努力做着macarena舞蹈的动作。"
    # ("Two men are dancing to the music, working hard on the Macarena moves.")
    # t2v_search_zh(text)