xinyiW915 committed
Commit a58c15b · verified · 1 Parent(s): 78aa13c

Upload 35 files

app.py CHANGED
@@ -69,7 +69,7 @@ def run_camp_vqa(video_path, intra_cross_experiment, is_finetune, train_data_name
     global model_cache
     if not model_cache:
         print("⏳ Loading models into cache... please wait")
-
+
         model_cache["slowfast"] = SlowFast().to(device)
         model_cache["swint"] = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
 
@@ -125,14 +125,49 @@ def run_camp_vqa(video_path, intra_cross_experiment, is_finetune, train_data_name
 
 def update_test_dataset(intra_cross_experiment, train_dataset):
     """
-    if intra: hide test dataset and value = train dataset
+    if intra:
+        - if train_dataset == lsvq_train → allow to choose test dataset (lsvq_test, lsvq_test_1080p)
+        - else: hide test dataset and set to same as train
     if cross: show test dataset dropdown
     """
     if intra_cross_experiment == "intra":
-        msg = f" Intra-dataset experiment — test dataset is automatically set to **{train_dataset}**."
-        return gr.update(value=train_dataset, visible=False), gr.update(value=msg, visible=True)
+        if train_dataset == "lsvq_train":
+            msg = " Intra LSVQ setting — please select between lsvq_test or lsvq_test_1080p."
+            return (
+                gr.update(
+                    choices=["lsvq_test", "lsvq_test_1080p"],
+                    value="lsvq_test",
+                    visible=True,
+                ),
+                gr.update(value=msg, visible=True),
+            )
+        else:
+            msg = f" Intra-dataset experiment — test dataset is automatically set to {train_dataset}."
+            return (
+                gr.update(value=train_dataset, visible=False),
+                gr.update(value=msg, visible=True),
+            )
     else:
-        return gr.update(visible=True), gr.update(value="", visible=False)
+        # cross: show full test dataset list
+        return (
+            gr.update(
+                choices=[
+                    "lsvq_train",
+                    "lsvq_test",
+                    "lsvq_test_1080p",
+                    "cvd_2014",
+                    "konvid_1k",
+                    "live_vqc",
+                    "youtube_ugc",
+                    "finevd",
+                    "live_yt_gaming",
+                    "kvq",
+                ],
+                visible=True,
+            ),
+            gr.update(value="", visible=False),
+        )
+
 
 def toggle_finetune_visibility(intra_cross_experiment, train_dataset):
     """
@@ -148,8 +183,7 @@ with gr.Blocks() as demo:
         "You can try our test video: "
         "<a href='https://huggingface.co/spaces/xinyiW915/CAMP-VQA/blob/main/ugc_original_videos/0_16_07_500001604801190-yase.mp4' target='_blank'> demo video</a>. "
         "<br><br>"
-        "⚙️ This demo is currently running on <strong>Hugging Face CPU Basic</strong>: 2 vCPU 16 GB RAM."
-        # "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
+        "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
     )
 
     with gr.Row():
@@ -170,7 +204,7 @@ with gr.Blocks() as demo:
             )
            test_dataset = gr.Dropdown(
                label="Test Dataset",
-               choices=["lsvq_test", "lsvq_test_1080p", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
+               choices=["lsvq_train", "lsvq_test", "lsvq_test_1080p", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
                value="finevd",
                visible=True
            )
@@ -216,4 +250,4 @@ with gr.Blocks() as demo:
         queue=True
     )
 
-demo.launch()
+demo.launch(share=True)
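
Note on the app.py change: the reworked update_test_dataset returns two gr.update objects, one for the Test Dataset dropdown and one for an info message. A minimal sketch of how such a two-output callback is typically wired in Gradio is below; the event bindings themselves are not part of this diff, so the component variable names intra_cross_experiment, train_dataset, test_dataset and info_msg are only illustrative.

# Sketch (not part of this commit): both controls re-run the callback and
# update the dropdown plus the message box whenever either value changes.
intra_cross_experiment.change(
    fn=update_test_dataset,
    inputs=[intra_cross_experiment, train_dataset],
    outputs=[test_dataset, info_msg],
)
train_dataset.change(
    fn=update_test_dataset,
    inputs=[intra_cross_experiment, train_dataset],
    outputs=[test_dataset, info_msg],
)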
demo_test.py CHANGED
@@ -1,219 +1,226 @@
 import argparse
 import os
 import sys
 import subprocess
 import json
 import ffmpeg
 import pandas as pd
 import torch
 import torch.nn as nn
 from tqdm import tqdm
 from torchvision import transforms
 import clip
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
 
 from extractor.extract_frag import VideoDataset_feature
 from extractor.extract_clip_embeds import extract_features_clip_embed
 from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
 from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
 from model_finetune import fix_state_dict
 
 
 def get_transform(resize):
     return transforms.Compose([transforms.Resize([resize, resize]),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])])
 
 def setup_device(config):
     if config.device == "gpu":
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if device.type == "cuda":
             torch.cuda.set_device(0)
     else:
         device = torch.device("cpu")
     print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
     return device
 
 def load_prompts(json_path):
     with open(json_path, "r", encoding="utf-8") as f:
         return json.load(f)
 
 def load_model(config, device, Mlp, input_features=13056):
     model = Mlp(input_features=input_features, out_features=1, drop_rate=0.1, act_layer=nn.GELU).to(device)
 
     if config.intra_cross_experiment == 'intra':
         if config.train_data_name == 'lsvq_train':
             if config.test_data_name == 'lsvq_test':
                 model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
             elif config.test_data_name == 'lsvq_test_1080p':
                 model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_1080p.pth")
             else:
-                print("Please use a cross-dataset experiment setting for the lsvq_train model to test it on another dataset, please try using the input 'cross' for 'intra_cross_experiment'.")
-                sys.exit(1)
+                raise ValueError(
+                    "❌ Invalid dataset combination for intra-dataset experiment.\n"
+                    "👉 When using `intra` with `lsvq_train`, please select test dataset as `lsvq_test` or `lsvq_test_1080p`.\n"
+                    "If you want to test on another dataset, please switch to the `cross` experiment setting."
+                )
+
         else:
             model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model.pth")
 
     elif config.intra_cross_experiment == 'cross':
         if config.train_data_name == 'lsvq_train':
             if config.is_finetune:
                 model_path = os.path.join(config.save_model_path, f"finetune/{config.test_data_name}_{config.network_name}_fine_tuned_model.pth")
             else:
                 model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
         else:
-            print("Invalid training data name for cross-experiment. We provided the lsvq_train model for the cross-experiment, please try using the input 'lsvq_train' for 'train_data_name'.")
-            sys.exit(1)
+            raise ValueError(
+                "❌ Invalid training dataset for cross-dataset experiment.\n"
+                "👉 The cross-dataset experiment supports `lsvq_train` as the only training dataset for fine-tuning models.\n"
+                "Please set `Train Dataset` to `lsvq_train` to continue."
+            )
 
     print("Loading model from:", model_path)
     state_dict = torch.load(model_path, map_location=device)
     fixed_state_dict = fix_state_dict(state_dict)
     try:
         model.load_state_dict(fixed_state_dict)
     except RuntimeError as e:
         print(e)
     return model
 
 def evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, model_mlp, device):
     # get video features
     model_slowfast.eval()
     model_swint.eval()
     clip_model.eval()
     blip_model.eval()
     with torch.no_grad():
         for i, (video_segments, video_res_frag_all, video_frag_all, video_name, frames_info, metadata) in enumerate(tqdm(data_loader, desc="Processing Videos")):
             # slowfast features
             _, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
             _, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
             _, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
             slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
             slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
             slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)
 
             # swinT feature
             swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
             swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
             swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
             swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
             swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
             swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)
 
             # semantic features
             image_embedding, quality_embedding, artifact_embedding = extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
             image_embedding_avg = image_embedding.mean(dim=0)
             quality_embedding_avg = quality_embedding.mean(dim=0)
             artifact_embedding_avg = artifact_embedding.mean(dim=0)
 
             # frame + residual fragment + frame fragment features
             slowfast_features = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg, slowfast_frame_frag_feats_avg), dim=0)
             swint_features = torch.cat((swint_frame_feats_avg, swint_res_frag_feats_avg, swint_frame_frag_feats_avg), dim=0)
             clip_features = torch.cat((image_embedding_avg, quality_embedding_avg, artifact_embedding_avg), dim=0)
             vqa_feats = torch.cat((slowfast_features, swint_features, clip_features), dim=0)
 
             vqa_feats = vqa_feats
             feature_tensor, _ = preprocess_data(vqa_feats, None)
             feature_tensor = feature_tensor.unsqueeze(0) if feature_tensor.dim() == 1 else feature_tensor
 
             model_mlp.eval()
             with torch.no_grad():
                 with torch.amp.autocast(device_type=device.type if device.type == 'cuda' else 'cpu'):
                     prediction = model_mlp(feature_tensor)
                     predicted_score = prediction.item()
     return predicted_score
 
 def parse_framerate(framerate_str):
     num, den = framerate_str.split('/')
     framerate = float(num)/float(den)
     return framerate
 
 def get_video_metadata(video_path):
     print(video_path)
     ffprobe_path = 'ffprobe'
     cmd = f'{ffprobe_path} -v error -select_streams v:0 -show_entries stream=width,height,nb_frames,r_frame_rate,bit_rate,bits_per_raw_sample,pix_fmt -of json {video_path}'
     try:
         result = subprocess.run(cmd, shell=True, capture_output=True, check=True)
         info = json.loads(result.stdout)
     except Exception as e:
         print(f"Error processing file {video_path}: {e}")
         return {}
 
     width = info['streams'][0]['width']
     height = info['streams'][0]['height']
     bitrate = info['streams'][0].get('bit_rate', 0)
     bitdepth = info['streams'][0].get('bits_per_raw_sample', 0)
     framerate = info['streams'][0]['r_frame_rate']
     framerate = parse_framerate(framerate)
     return width, height, bitrate, bitdepth, framerate
 
 def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument('--device', type=str, default='gpu', help='cpu or gpu')
     parser.add_argument('--model_name', type=str, default='Mlp')
     parser.add_argument('--select_criteria', type=str, default='byrmse')
-    parser.add_argument('--intra_cross_experiment', type=str, default='cross', help='intra or cross')
+    parser.add_argument('--intra_cross_experiment', type=str, default='intra', help='intra or cross')
     parser.add_argument('--is_finetune', type=bool, default=True, help='True or False')
-    parser.add_argument('--save_model_path', type=str, default='../model/')
+    parser.add_argument('--save_model_path', type=str, default='./model/')
     parser.add_argument('--prompt_path', type=str, default="./config/prompts.json")
 
-    parser.add_argument('--train_data_name', type=str, default='lsvq_train', help='Name of the training data')
+    parser.add_argument('--train_data_name', type=str, default='finevd', help='Name of the training data')
     parser.add_argument('--test_data_name', type=str, default='finevd', help='Name of the testing data')
-    parser.add_argument('--test_video_path', type=str, default='../test_videos/0_16_07_500001604801190-yase.mp4', help='demo test video')
+    parser.add_argument('--test_video_path', type=str, default='./ugc_original_videos/0_16_07_500001604801190-yase.mp4', help='demo test video')
     parser.add_argument('--prediction_mode', type=float, default=50, help='default for inference')
 
     parser.add_argument('--network_name', type=str, default='camp-vqa')
     parser.add_argument('--num_workers', type=int, default=4)
     parser.add_argument('--resize', type=int, default=224)
     parser.add_argument('--patch_size', type=int, default=16)
     parser.add_argument('--target_size', type=int, default=224)
     args = parser.parse_args()
     return args
 
 if __name__ == '__main__':
     config = parse_arguments()
     device = setup_device(config)
     prompts = load_prompts(config.prompt_path)
 
     # test demo video
     resize_transform = get_transform(config.resize)
     top_n = int(config.target_size /config. patch_size) * int(config.target_size / config.patch_size)
 
     width, height, bitrate, bitdepth, framerate = get_video_metadata(config.test_video_path)
 
     data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
             'test_data_name': [config.test_data_name],
             'test_video_path': [config.test_video_path],
             'prediction_mode': [config.prediction_mode],
             'width': [width], 'height': [height], 'bitrate': [bitrate], 'bitdepth': [bitdepth], 'framerate': [framerate]}
     videos_dir = os.path.dirname(config.test_video_path)
     test_df = pd.DataFrame(data)
     print(test_df.T)
     print(f"Experiment Setting: {config.intra_cross_experiment}, {config.train_data_name} -> {config.test_data_name}")
     if config.intra_cross_experiment == 'cross':
         if config.train_data_name == 'lsvq_train':
             print(f"Fine-tune: {config.is_finetune}")
 
     dataset = VideoDataset_feature(test_df, videos_dir, config.test_data_name, resize_transform, config.resize, config.patch_size, config.target_size, top_n)
 
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=1, shuffle=False, num_workers = min(config.num_workers, os.cpu_count() or 1), pin_memory = device.type == "cuda"
     )
     print(f"Model: {config.network_name} | Dataset: {config.test_data_name} | Device: {device}")
 
     # load models to device
     model_slowfast = SlowFast().to(device)
     model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
 
     clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
     blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
     blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
 
     input_features = 13056
     if config.intra_cross_experiment == 'intra':
         if config.train_data_name == 'lsvq_train':
             from model_regression_lsvq import Mlp, preprocess_data
         else:
             from model_regression import Mlp, preprocess_data
     elif config.intra_cross_experiment == 'cross':
         from model_regression_lsvq import Mlp, preprocess_data
     model_mlp = load_model(config, device, Mlp, input_features)
 
     quality_prediction = evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, model_mlp, device)
     print("Predicted Quality Score:", quality_prediction)
extractor/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (179 Bytes)

extractor/__pycache__/extract_frag.cpython-38.pyc ADDED
Binary file (9.22 kB)

extractor/__pycache__/extract_slowfast_clip.cpython-38.pyc ADDED
Binary file (2.26 kB)

extractor/__pycache__/extract_swint_clip.cpython-38.pyc ADDED
Binary file (1.75 kB)
 
model_regression.py CHANGED
@@ -110,7 +110,8 @@ def preprocess_data(X, y):
     X_min = X.min(dim=0, keepdim=True).values
     X_max = X.max(dim=0, keepdim=True).values
     X = (X - X_min) / (X_max - X_min)
-    y = y.view(-1, 1).squeeze()
+    if y is not None:
+        y = y.view(-1, 1).squeeze()
     return X, y
 
 # define 4-parameter logistic regression
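
Note on the model_regression.py change: the guard lets preprocess_data be shared between training (labels present) and inference, where demo_test.py calls it with y=None. A minimal sketch of both call patterns (not part of this commit; tensor shapes are illustrative):

# Sketch: preprocess_data with and without labels.
import torch
from model_regression import preprocess_data

feats = torch.rand(4, 13056)   # illustrative feature batch
labels = torch.rand(4)         # training-style MOS labels

X_train, y_train = preprocess_data(feats, labels)  # y is reshaped and squeezed
X_infer, y_infer = preprocess_data(feats, None)    # inference: y stays None
assert y_infer is None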