import argparse
import os
import subprocess
import json

import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import transforms
import clip
from transformers import Blip2Processor, Blip2ForConditionalGeneration

from extractor.extract_frag import VideoDataset_feature
from extractor.extract_clip_embeds import extract_features_clip_embed
from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
from model_finetune import fix_state_dict


def get_transform(resize):
    return transforms.Compose([
        transforms.Resize([resize, resize]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
    ])


def setup_device(config):
    if config.device == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type == "cuda":
            torch.cuda.set_device(0)
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
    return device


def load_prompts(json_path):
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_model(config, device, Mlp, input_features=13056):
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.1, act_layer=nn.GELU).to(device)
    if config.intra_cross_experiment == 'intra':
        if config.train_data_name == 'lsvq_train':
            if config.test_data_name == 'lsvq_test':
                model_path = os.path.join(
                    config.save_model_path,
                    f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
            elif config.test_data_name == 'lsvq_test_1080p':
                model_path = os.path.join(
                    config.save_model_path,
                    f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_1080p.pth")
            else:
                raise ValueError(
                    "❌ Invalid dataset combination for intra-dataset experiment.\n"
                    "👉 When using `intra` with `lsvq_train`, please select `lsvq_test` or `lsvq_test_1080p` as the test dataset.\n"
                    "If you want to test on another dataset, please switch to the `cross` experiment setting."
                )
        else:
            model_path = os.path.join(
                config.save_model_path,
                f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model.pth")
    elif config.intra_cross_experiment == 'cross':
        if config.train_data_name == 'lsvq_train':
            if config.is_finetune:
                model_path = os.path.join(
                    config.save_model_path,
                    f"finetune/{config.test_data_name}_{config.network_name}_fine_tuned_model.pth")
            else:
                model_path = os.path.join(
                    config.save_model_path,
                    f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
        else:
            raise ValueError(
                "❌ Invalid training dataset for cross-dataset experiment.\n"
                "👉 The cross-dataset experiment supports `lsvq_train` as the only training dataset for fine-tuning models.\n"
                "Please set `Train Dataset` to `lsvq_train` to continue."
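# Illustrative resolved checkpoint paths, derived from the f-strings above and the
# argument defaults below (train_data_name='finevd', network_name='camp-vqa',
# model_name='Mlp', select_criteria='byrmse'; actual files depend on your setup):
#   intra, train != lsvq_train:       ./model/wo_finetune/finevd_camp-vqa_Mlp_byrmse_trained_model.pth
#   cross, lsvq_train, is_finetune:   ./model/finetune/finevd_camp-vqa_fine_tuned_model.pth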
            )
    else:
        # Fail early instead of referencing an undefined model_path below.
        raise ValueError(f"Unknown experiment setting: {config.intra_cross_experiment!r} (expected 'intra' or 'cross').")

    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        # Report key/shape mismatches but keep going with whatever loaded.
        print(e)
    return model
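# For reference, ffprobe's JSON output for the probed stream looks roughly like
# the following (field values illustrative; note bit_rate, bits_per_raw_sample,
# and nb_frames come back as strings):
#   {"streams": [{"width": 1920, "height": 1080, "pix_fmt": "yuv420p",
#                 "r_frame_rate": "30000/1001", "bit_rate": "2500000",
#                 "bits_per_raw_sample": "8", "nb_frames": "300"}]}
# so parse_framerate("30000/1001") evaluates to ~29.97.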
def evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint,
                           clip_model, clip_preprocess, blip_processor, blip_model,
                           prompts, model_mlp, device):
    # Put all feature extractors in eval mode before extracting video features.
    model_slowfast.eval()
    model_swint.eval()
    clip_model.eval()
    blip_model.eval()
    with torch.no_grad():
        for video_segments, video_res_frag_all, video_frag_all, video_name, frames_info, metadata in tqdm(
                data_loader, desc="Processing Videos"):
            # SlowFast features
            _, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
            _, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
            _, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
            slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
            slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
            slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)

            # Swin-T features
            swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
            swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
            swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
            swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
            swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
            swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)

            # Semantic features (CLIP image/quality/artifact embeddings)
            image_embedding, quality_embedding, artifact_embedding = extract_features_clip_embed(
                frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
            image_embedding_avg = image_embedding.mean(dim=0)
            quality_embedding_avg = quality_embedding.mean(dim=0)
            artifact_embedding_avg = artifact_embedding.mean(dim=0)

            # Concatenate frame + residual-fragment + frame-fragment features per backbone
            slowfast_features = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg,
                                           slowfast_frame_frag_feats_avg), dim=0)
            swint_features = torch.cat((swint_frame_feats_avg, swint_res_frag_feats_avg,
                                        swint_frame_frag_feats_avg), dim=0)
            clip_features = torch.cat((image_embedding_avg, quality_embedding_avg,
                                       artifact_embedding_avg), dim=0)
            vqa_feats = torch.cat((slowfast_features, swint_features, clip_features), dim=0)

            feature_tensor, _ = preprocess_data(vqa_feats, None)
            feature_tensor = feature_tensor.unsqueeze(0) if feature_tensor.dim() == 1 else feature_tensor

            model_mlp.eval()
            with torch.amp.autocast(device_type=device.type):
                prediction = model_mlp(feature_tensor)
            predicted_score = prediction.item()
            # The demo data loader contains a single video, so return after the first item.
            return predicted_score


def parse_framerate(framerate_str):
    # ffprobe reports r_frame_rate as a fraction string, e.g. "30000/1001".
    num, den = framerate_str.split('/')
    return float(num) / float(den)


def get_video_metadata(video_path):
    print(video_path)
    # Pass arguments as a list (no shell=True) so paths with spaces survive.
    cmd = [
        'ffprobe', '-v', 'error', '-select_streams', 'v:0',
        '-show_entries', 'stream=width,height,nb_frames,r_frame_rate,bit_rate,bits_per_raw_sample,pix_fmt',
        '-of', 'json', video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, check=True)
        info = json.loads(result.stdout)
    except Exception as e:
        # Fail loudly: the caller unpacks five values, so returning an empty dict
        # would only defer the crash.
        raise RuntimeError(f"Error processing file {video_path}: {e}") from e
    stream = info['streams'][0]
    width = stream['width']
    height = stream['height']
    bitrate = stream.get('bit_rate', 0)
    bitdepth = stream.get('bits_per_raw_sample', 0)
    framerate = parse_framerate(stream['r_frame_rate'])
    return width, height, bitrate, bitdepth, framerate


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--select_criteria', type=str, default='byrmse')
    parser.add_argument('--intra_cross_experiment', type=str, default='intra', help='intra or cross')
    # Note: argparse's type=bool treats every non-empty string (including "False")
    # as True, so parse the flag value explicitly.
    parser.add_argument('--is_finetune', type=lambda x: str(x).lower() == 'true', default=True, help='True or False')
    parser.add_argument('--save_model_path', type=str, default='./model/')
    parser.add_argument('--prompt_path', type=str, default="./config/prompts.json")
    parser.add_argument('--train_data_name', type=str, default='finevd', help='Name of the training data')
    parser.add_argument('--test_data_name', type=str, default='finevd', help='Name of the testing data')
    parser.add_argument('--test_video_path', type=str, default='./ugc_original_videos/0_16_07_500001604801190-yase.mp4', help='demo test video')
    parser.add_argument('--prediction_mode', type=float, default=50, help='default for inference')
    parser.add_argument('--network_name', type=str, default='camp-vqa')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--resize', type=int, default=224)
    parser.add_argument('--patch_size', type=int, default=16)
    parser.add_argument('--target_size', type=int, default=224)
    return parser.parse_args()


if __name__ == '__main__':
    config = parse_arguments()
    device = setup_device(config)
    prompts = load_prompts(config.prompt_path)

    # Test demo video: with target_size=224 and patch_size=16, top_n = 14 * 14 = 196 patches.
    resize_transform = get_transform(config.resize)
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)
    width, height, bitrate, bitdepth, framerate = get_video_metadata(config.test_video_path)
    data = {
        'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
        'test_data_name': [config.test_data_name],
        'test_video_path': [config.test_video_path],
        'prediction_mode': [config.prediction_mode],
        'width': [width],
        'height': [height],
        'bitrate': [bitrate],
        'bitdepth': [bitdepth],
        'framerate': [framerate],
    }
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    print(test_df.T)

    print(f"Experiment Setting: {config.intra_cross_experiment}, {config.train_data_name} -> {config.test_data_name}")
    if config.intra_cross_experiment == 'cross' and config.train_data_name == 'lsvq_train':
        print(f"Fine-tune: {config.is_finetune}")

    dataset = VideoDataset_feature(test_df, videos_dir, config.test_data_name, resize_transform,
                                   config.resize, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=min(config.num_workers, os.cpu_count() or 1),
        pin_memory=(device.type == "cuda"),
    )
    print(f"Model: {config.network_name} | Dataset: {config.test_data_name} | Device: {device}")

    # Load the feature extractors onto the device.
    model_slowfast = SlowFast().to(device)
    model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)

    input_features = 13056
    if config.intra_cross_experiment == 'intra':
        if config.train_data_name == 'lsvq_train':
            from model_regression_lsvq import Mlp, preprocess_data
        else:
            from model_regression import Mlp, preprocess_data
    elif config.intra_cross_experiment == 'cross':
        from model_regression_lsvq import Mlp, preprocess_data
    model_mlp = load_model(config, device, Mlp, input_features)

    quality_prediction = evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint,
                                                clip_model, clip_preprocess, blip_processor, blip_model,
                                                prompts, model_mlp, device)
    print("Predicted Quality Score:", quality_prediction)
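# Example invocation (assuming this file is saved as test_demo.py; adjust paths
# to your checkout):
#   python test_demo.py --device gpu --intra_cross_experiment cross \
#       --train_data_name lsvq_train --test_data_name finevd \
#       --test_video_path ./ugc_original_videos/0_16_07_500001604801190-yase.mp4
# This prints the probed metadata, the experiment setting, and finally a line like
# "Predicted Quality Score: <float>".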