xinyiW915 commited on
Commit
7509a87
·
verified ·
1 Parent(s): 4606046

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +176 -0
  2. demo_test.py +220 -0
  3. model_finetune.py +326 -0
  4. model_regression.py +656 -0
  5. model_regression_lsvq.py +666 -0
  6. requirements.txt +62 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from spaces import GPU
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import pandas as pd
6
+ from types import SimpleNamespace
7
+ import clip
8
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
9
+
10
+ from extractor.extract_frag import VideoDataset_feature
11
+ from extractor.extract_slowfast_clip import SlowFast
12
+ from extractor.extract_swint_clip import SwinT
13
+ from demo_test import get_transform, load_prompts, get_video_metadata, load_model, evaluate_video_quality
14
+
15
+
16
# Process-wide cache so heavy backbones are loaded once per worker, not per request.
model_cache = {}

@GPU
def run_camp_vqa(video_path, intra_cross_experiment, is_finetune, train_data_name, test_data_name, network_name):
    """Gradio handler: predict a perceptual quality score (0-100) for an uploaded video.

    Args:
        video_path: local path of the uploaded video (Gradio passes None when
            nothing was uploaded).
        intra_cross_experiment: 'intra' or 'cross' experiment setting.
        is_finetune: whether to use the fine-tuned checkpoint (cross setting).
        train_data_name / test_data_name: dataset identifiers used to pick the
            checkpoint file.
        network_name: model identifier (currently 'camp-vqa').

    Returns:
        A markdown string with the predicted score, or an error message.
    """
    # Fix: Gradio passes None when no file was uploaded; os.path.exists(None)
    # would raise TypeError before the friendly message could be returned.
    if not video_path or not os.path.exists(video_path):
        return "❌ No video uploaded or the uploaded file has expired. Please upload again."

    print("CUDA available:", torch.cuda.is_available())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print("Current device:", torch.cuda.current_device())
    else:
        print("Running on CPU")

    # Mirror the CLI arguments of demo_test.py so the shared helpers work unchanged.
    config = SimpleNamespace(**{
        'model_name': 'Mlp',
        'select_criteria': 'byrmse',
        'intra_cross_experiment': intra_cross_experiment,
        'is_finetune': is_finetune,
        'save_model_path': 'model/',
        'prompt_path': './config/prompts.json',
        'train_data_name': train_data_name,
        'test_data_name': test_data_name,
        'test_video_path': video_path,
        'prediction_mode': 50,
        'network_name': network_name,
        'num_workers': 2,
        'resize': 224,
        'patch_size': 16,
        'target_size': 224,
    })
    print(f"Test video path: {config.test_video_path}")

    # test demo video
    resize_transform = get_transform(config.resize)
    # Number of patches per frame: (target_size / patch_size)^2.
    top_n = int(config.target_size / config.patch_size) ** 2

    width, height, bitrate, bitdepth, framerate = get_video_metadata(config.test_video_path)

    # Single-row DataFrame describing the one video to evaluate.
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path],
            'prediction_mode': [config.prediction_mode],
            'width': [width], 'height': [height], 'bitrate': [bitrate], 'bitdepth': [bitdepth], 'framerate': [framerate]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    print(test_df.T)
    print(f"Experiment Setting: {config.intra_cross_experiment}, {config.train_data_name} -> {config.test_data_name}")
    if config.intra_cross_experiment == 'cross':
        if config.train_data_name == 'lsvq_train':
            print(f"Fine-tune: {config.is_finetune}")

    # Load backbones into the cache on first call only (cold start is expensive).
    global model_cache
    if not model_cache:
        print("Loading models into cache (first time)...")
        model_cache["slowfast"] = SlowFast().to(device)
        model_cache["swint"] = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)

        model_cache["clip"], model_cache["clip_preprocess"] = clip.load("ViT-B/32", device=device)
        model_cache["blip_processor"] = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
        model_cache["blip_model"] = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
        print("Model cache initialized.")

    # get model from cache
    model_slowfast = model_cache["slowfast"]
    model_swint = model_cache["swint"]
    clip_model, clip_preprocess = model_cache["clip"], model_cache["clip_preprocess"]
    blip_processor = model_cache["blip_processor"]
    blip_model = model_cache["blip_model"]

    # Fused feature dimensionality expected by the Mlp head.
    input_features = 13056
    # The regression head / preprocessing differ by experiment setting.
    if config.intra_cross_experiment == 'intra':
        if config.train_data_name == 'lsvq_train':
            from model_regression_lsvq import Mlp, preprocess_data
        else:
            from model_regression import Mlp, preprocess_data
    elif config.intra_cross_experiment == 'cross':
        from model_regression_lsvq import Mlp, preprocess_data
    model_mlp = load_model(config, device, Mlp, input_features)
    prompts = load_prompts(config.prompt_path)

    dataset = VideoDataset_feature(test_df, videos_dir, config.test_data_name, resize_transform, config.resize, config.patch_size, config.target_size, top_n)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count() or 1), pin_memory=(device.type == "cuda")
    )

    try:
        score = evaluate_video_quality(
            preprocess_data,
            data_loader,
            model_slowfast,
            model_swint,
            clip_model,
            clip_preprocess,
            blip_processor,
            blip_model,
            prompts,
            model_mlp,
            device
        )
        return f"**Predicted Perceptual Quality Score:** {score:.4f} / 100"

    except Exception as e:
        return f"❌ Error: {str(e)}"
    finally:
        # Best-effort cleanup of Gradio's temp upload to keep the Space's disk small.
        if "gradio" in video_path and os.path.exists(video_path):
            os.remove(video_path)
124
+
125
def toggle_dataset_visibility(is_finetune):
    """Gradio callback: show a component only while finetuning is enabled.

    Returns a gr.update dict toggling `visible`. NOTE(review): this helper is
    not wired to any event in the visible UI code — confirm whether a
    `is_finetune_checkbox.change(...)` hookup was intended.
    """
    return gr.update(visible=is_finetune)
127
+
128
# Build the Gradio UI: inputs on the left, predicted score on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 📹 CAMP-VQA Online Demo")
    # Adjacent string literals concatenate into a single markdown blob.
    gr.Markdown(
        "Upload a short video and get its perceptual quality score predicted by CAMP-VQA."
        "You can try our test video"
        "<a href='https://huggingface.co/spaces/xinyiW915/CAMP-VQA/blob/main/ugc_original_videos/0_16_07_500001604801190-yase.mp4' target='_blank'>demo video</a>. "
        "<br><br>"
        # "⚙️ This demo is currently running on <strong>Hugging Face CPU Basic</strong>: 2 vCPU • 16 GB RAM."
        "⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
    )

    with gr.Row():
        # Left column: all prediction inputs.
        with gr.Column(scale=2):
            video_input = gr.Video(label="Upload a Video (e.g. .mp4)")
            intra_cross_experiment = gr.Dropdown(
                label="Intra or Cross experiment",
                choices=["intra", "cross"],
                value="cross"
            )
            is_finetune_checkbox = gr.Checkbox(label="Use Finetuning?", value=False)
            train_dataset = gr.Dropdown(
                label="Train Dataset",
                choices=["lsvq_train", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
                value="lsvq_train"
            )
            test_dataset = gr.Dropdown(
                label="Test Dataset for Finetuning",
                choices=["lsvq_test", "lsvq_test_1080p", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
                value="finevd"
            )
            model_dropdown = gr.Dropdown(
                label="Our Models",
                choices=["camp-vqa"],
                value="camp-vqa"
            )
            run_button = gr.Button("Run Prediction")

        # Right column: textual result.
        with gr.Column(scale=1):
            output_box = gr.Textbox(label="Predicted Quality Score (0–100)", lines=5)

    # Wire the button to the GPU-decorated handler; queue=True serializes requests.
    run_button.click(
        fn=run_camp_vqa,
        inputs=[video_input, intra_cross_experiment, is_finetune_checkbox, train_dataset, test_dataset, model_dropdown],
        outputs=output_box,
        api_name="run",
        queue=True
    )

demo.launch()
demo_test.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import subprocess
5
+ import json
6
+ import ffmpeg
7
+ import pandas as pd
8
+ import torch
9
+ import torch.nn as nn
10
+ from tqdm import tqdm
11
+ from torchvision import transforms
12
+ import clip
13
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
14
+
15
+ from extractor.extract_frag import VideoDataset_feature
16
+ from extractor.extract_clip_embeds import extract_features_clip_embed
17
+ from extractor.extract_slowfast_clip import SlowFast, extract_features_slowfast_pool
18
+ from extractor.extract_swint_clip import SwinT, extract_features_swint_pool
19
+ from model_finetune import fix_state_dict
20
+
21
+
22
def get_transform(resize):
    """Build the per-frame preprocessing pipeline.

    Square-resizes to (resize, resize), converts to a tensor, then applies the
    normalization constants used by the feature backbones.
    """
    steps = [
        transforms.Resize([resize, resize]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
    ]
    return transforms.Compose(steps)
26
+
27
def setup_device(config):
    """Resolve the torch device requested by ``config.device`` ('gpu' or 'cpu').

    Falls back to CPU when CUDA is unavailable; pins GPU work to device 0.
    """
    use_cuda = config.device == "gpu" and torch.cuda.is_available()
    if use_cuda:
        device = torch.device("cuda")
        torch.cuda.set_device(0)
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
    return device
36
+
37
def load_prompts(json_path):
    """Read and return the prompt definitions stored as JSON at *json_path*."""
    with open(json_path, "r", encoding="utf-8") as handle:
        prompts = json.load(handle)
    return prompts
40
+
41
def load_model(config, device, Mlp, input_features=13056):
    """Instantiate the Mlp regression head and load the checkpoint matching *config*.

    Args:
        config: namespace with intra_cross_experiment, train/test_data_name,
            is_finetune, save_model_path, network_name, model_name, select_criteria.
        device: torch device to place the model on.
        Mlp: regression-head class (imported per experiment setting by the caller).
        input_features: fused feature dimensionality fed to the head.

    Returns:
        The Mlp model with checkpoint weights loaded (best effort — see below).

    Raises:
        ValueError: if intra_cross_experiment is neither 'intra' nor 'cross'
            (previously this fell through and crashed with NameError on model_path).
    """
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.1, act_layer=nn.GELU).to(device)

    model_path = None
    if config.intra_cross_experiment == 'intra':
        if config.train_data_name == 'lsvq_train':
            if config.test_data_name == 'lsvq_test':
                model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
            elif config.test_data_name == 'lsvq_test_1080p':
                model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_1080p.pth")
            else:
                print("Please use a cross-dataset experiment setting for the lsvq_train model to test it on another dataset, please try using the input 'cross' for 'intra_cross_experiment'.")
                sys.exit(1)
        else:
            model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model.pth")

    elif config.intra_cross_experiment == 'cross':
        if config.train_data_name == 'lsvq_train':
            if config.is_finetune:
                model_path = os.path.join(config.save_model_path, f"finetune/{config.test_data_name}_{config.network_name}_fine_tuned_model.pth")
            else:
                model_path = os.path.join(config.save_model_path, f"wo_finetune/{config.train_data_name}_{config.network_name}_{config.model_name}_{config.select_criteria}_trained_model_kfold.pth")
        else:
            print("Invalid training data name for cross-experiment. We provided the lsvq_train model for the cross-experiment, please try using the input 'lsvq_train' for 'train_data_name'.")
            sys.exit(1)

    if model_path is None:
        # Fix: previously an unknown setting left model_path unbound -> NameError.
        raise ValueError(f"Unknown intra_cross_experiment setting: {config.intra_cross_experiment!r} (expected 'intra' or 'cross')")

    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        # Deliberately tolerant: report the mismatch but keep the partially
        # initialized model so the demo can still run.
        print(e)
    return model
74
+
75
def evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, model_mlp, device):
    """Extract fused motion/appearance/semantic features for one video and predict its quality.

    The loader is expected to yield a single video (batch_size=1); the function
    returns after the first batch. Returns the scalar predicted score, or None
    implicitly if *data_loader* is empty.
    """
    # get video features
    model_slowfast.eval()
    model_swint.eval()
    clip_model.eval()
    blip_model.eval()
    with torch.no_grad():
        for i, (video_segments, video_res_frag_all, video_frag_all, video_name, frames_info, metadata) in enumerate(tqdm(data_loader, desc="Processing Videos")):
            # SlowFast motion features for frames, residual fragments and frame fragments,
            # each averaged over the temporal dimension.
            _, _, slowfast_frame_feats = extract_features_slowfast_pool(video_segments, model_slowfast, device)
            _, _, slowfast_res_frag_feats = extract_features_slowfast_pool(video_res_frag_all, model_slowfast, device)
            _, _, slowfast_frame_frag_feats = extract_features_slowfast_pool(video_frag_all, model_slowfast, device)
            slowfast_frame_feats_avg = slowfast_frame_feats.mean(dim=0)
            slowfast_res_frag_feats_avg = slowfast_res_frag_feats.mean(dim=0)
            slowfast_frame_frag_feats_avg = slowfast_frame_frag_feats.mean(dim=0)

            # SwinT appearance features for the same three views.
            swint_frame_feats = extract_features_swint_pool(video_segments, model_swint, device)
            swint_res_frag_feats = extract_features_swint_pool(video_res_frag_all, model_swint, device)
            swint_frame_frag_feats = extract_features_swint_pool(video_frag_all, model_swint, device)
            swint_frame_feats_avg = swint_frame_feats.mean(dim=0)
            swint_res_frag_feats_avg = swint_res_frag_feats.mean(dim=0)
            swint_frame_frag_feats_avg = swint_frame_frag_feats.mean(dim=0)

            # CLIP/BLIP semantic embeddings (image, quality-prompt, artifact-prompt).
            image_embedding, quality_embedding, artifact_embedding = extract_features_clip_embed(frames_info, metadata, clip_model, clip_preprocess, blip_processor, blip_model, prompts, device)
            image_embedding_avg = image_embedding.mean(dim=0)
            quality_embedding_avg = quality_embedding.mean(dim=0)
            artifact_embedding_avg = artifact_embedding.mean(dim=0)

            # Concatenate frame + residual fragment + frame fragment features per backbone,
            # then fuse all backbones into one vector.
            slowfast_features = torch.cat((slowfast_frame_feats_avg, slowfast_res_frag_feats_avg, slowfast_frame_frag_feats_avg), dim=0)
            swint_features = torch.cat((swint_frame_feats_avg, swint_res_frag_feats_avg, swint_frame_frag_feats_avg), dim=0)
            clip_features = torch.cat((image_embedding_avg, quality_embedding_avg, artifact_embedding_avg), dim=0)
            vqa_feats = torch.cat((slowfast_features, swint_features, clip_features), dim=0)

            # Fix: removed dead no-op `vqa_feats = vqa_feats`.
            feature_tensor, _ = preprocess_data(vqa_feats, None)
            feature_tensor = feature_tensor.unsqueeze(0) if feature_tensor.dim() == 1 else feature_tensor
            print(f"Feature tensor shape before MLP: {feature_tensor.shape}")

            model_mlp.eval()
            # Fix: dropped a redundant nested torch.no_grad() (already inside one).
            # Non-cuda, non-cpu device types (e.g. 'mps') fall back to cpu autocast.
            with torch.amp.autocast(device_type=device.type if device.type == 'cuda' else 'cpu'):
                prediction = model_mlp(feature_tensor)
            predicted_score = prediction.item()
            # Single-video inference: stop after the first batch.
            return predicted_score
122
+
123
def parse_framerate(framerate_str):
    """Convert an ffprobe rational frame-rate string to a float.

    ffprobe reports r_frame_rate as 'num/den' (e.g. '30000/1001'); some
    containers yield a plain number, and streams without timing info report
    '0/0'. Returns 0.0 for a zero denominator instead of raising
    ZeroDivisionError (previous behavior).
    """
    if '/' in framerate_str:
        num, den = framerate_str.split('/')
        if float(den) == 0.0:
            return 0.0
        return float(num) / float(den)
    return float(framerate_str)
127
+
128
def get_video_metadata(video_path):
    """Probe *video_path* with ffprobe for basic stream properties.

    Returns:
        (width, height, bitrate, bitdepth, framerate). On any failure returns
        (None, None, None, None, None) so callers that unpack five values keep
        working (the previous code returned {}, which crashed the unpack).
    """
    print(video_path)
    # Fix: argument list with shell=False — the old shell=True f-string command
    # broke on paths with spaces and was shell-injection prone.
    cmd = [
        'ffprobe', '-v', 'error', '-select_streams', 'v:0',
        '-show_entries', 'stream=width,height,nb_frames,r_frame_rate,bit_rate,bits_per_raw_sample,pix_fmt',
        '-of', 'json', video_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, check=True)
        info = json.loads(result.stdout)
        # Parsing also lives in the try so malformed/missing streams hit the fallback.
        stream = info['streams'][0]
        width = stream['width']
        height = stream['height']
        bitrate = stream.get('bit_rate', 0)
        bitdepth = stream.get('bits_per_raw_sample', 0)
        framerate = parse_framerate(stream['r_frame_rate'])
    except Exception as e:
        print(f"Error processing file {video_path}: {e}")
        return (None, None, None, None, None)
    return width, height, bitrate, bitdepth, framerate
146
+
147
def parse_arguments(argv=None):
    """Parse CLI options for the demo.

    Args:
        argv: optional argument list; defaults to sys.argv[1:] (backward-compatible).

    Returns:
        argparse.Namespace with all demo settings.
    """
    def str2bool(value):
        # Fix: argparse `type=bool` treats every non-empty string as True,
        # so `--is_finetune False` silently parsed as True.
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() in ('true', '1', 'yes', 'y', 't')

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--select_criteria', type=str, default='byrmse')
    parser.add_argument('--intra_cross_experiment', type=str, default='cross', help='intra or cross')
    parser.add_argument('--is_finetune', type=str2bool, default=True, help='True or False')
    parser.add_argument('--save_model_path', type=str, default='../model/')
    parser.add_argument('--prompt_path', type=str, default="./config/prompts.json")

    parser.add_argument('--train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    parser.add_argument('--test_data_name', type=str, default='finevd', help='Name of the testing data')
    parser.add_argument('--test_video_path', type=str, default='../test_videos/0_16_07_500001604801190-yase.mp4', help='demo test video')
    parser.add_argument('--prediction_mode', type=float, default=50, help='default for inference')

    parser.add_argument('--network_name', type=str, default='camp-vqa')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--resize', type=int, default=224)
    parser.add_argument('--patch_size', type=int, default=16)
    parser.add_argument('--target_size', type=int, default=224)
    args = parser.parse_args(argv)
    return args
169
+
170
if __name__ == '__main__':
    # CLI entry point: probe the demo video, build the feature pipeline,
    # load all backbones and the regression head, then print one score.
    config = parse_arguments()
    device = setup_device(config)
    prompts = load_prompts(config.prompt_path)

    # test demo video
    resize_transform = get_transform(config.resize)
    # Patches per frame: (target_size / patch_size)^2.
    top_n = int(config.target_size / config.patch_size) * int(config.target_size / config.patch_size)

    width, height, bitrate, bitdepth, framerate = get_video_metadata(config.test_video_path)

    # Single-row DataFrame describing the one video under test.
    data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
            'test_data_name': [config.test_data_name],
            'test_video_path': [config.test_video_path],
            'prediction_mode': [config.prediction_mode],
            'width': [width], 'height': [height], 'bitrate': [bitrate], 'bitdepth': [bitdepth], 'framerate': [framerate]}
    videos_dir = os.path.dirname(config.test_video_path)
    test_df = pd.DataFrame(data)
    print(test_df.T)
    print(f"Experiment Setting: {config.intra_cross_experiment}, {config.train_data_name} -> {config.test_data_name}")
    if config.intra_cross_experiment == 'cross':
        if config.train_data_name == 'lsvq_train':
            print(f"Fine-tune: {config.is_finetune}")

    dataset = VideoDataset_feature(test_df, videos_dir, config.test_data_name, resize_transform, config.resize, config.patch_size, config.target_size, top_n)

    # Cap workers at the machine's CPU count; pin memory only when on GPU.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count() or 1), pin_memory=device.type == "cuda"
    )
    print(f"Model: {config.network_name} | Dataset: {config.test_data_name} | Device: {device}")

    # load models to device
    model_slowfast = SlowFast().to(device)
    model_swint = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)

    clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
    blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)

    # Fused feature dimensionality expected by the Mlp head.
    input_features = 13056
    # Regression head / preprocessing differ by experiment setting.
    if config.intra_cross_experiment == 'intra':
        if config.train_data_name == 'lsvq_train':
            from model_regression_lsvq import Mlp, preprocess_data
        else:
            from model_regression import Mlp, preprocess_data
    elif config.intra_cross_experiment == 'cross':
        from model_regression_lsvq import Mlp, preprocess_data
    model_mlp = load_model(config, device, Mlp, input_features)

    quality_prediction = evaluate_video_quality(preprocess_data, data_loader, model_slowfast, model_swint, clip_model, clip_preprocess, blip_processor, blip_model, prompts, model_mlp, device)
    print("Predicted Quality Score:", quality_prediction)
model_finetune.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pandas as pd
3
+ import numpy as np
4
+ import math
5
+ import os
6
+ import scipy.io
7
+ import scipy.stats
8
+ from scipy.optimize import curve_fit
9
+ from sklearn.model_selection import train_test_split
10
+ import seaborn as sns
11
+ import matplotlib.pyplot as plt
12
+ import copy
13
+ from joblib import dump, load
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.optim as optim
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR
18
+ from torch.optim.swa_utils import AveragedModel, SWALR
19
+ from torch.utils.data import DataLoader, TensorDataset
20
+ from model_regression_lsvq import Mlp, MAEAndRankLoss, preprocess_data, compute_correlation_metrics, logistic_func, plot_results
21
+
22
# Module-level device selection: use the first CUDA GPU when available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
    # Pin all subsequent CUDA work in this module to GPU 0.
    torch.cuda.set_device(0)
25
+
26
def create_results_dataframe(data_list, network_name, srcc_list, krcc_list, plcc_list, rmse_list, select_criteria_list):
    """Assemble per-dataset evaluation metrics into one summary DataFrame.

    *network_name* is a scalar broadcast to every row; the remaining arguments
    are per-dataset lists of equal length.
    """
    return pd.DataFrame({
        'DATASET': data_list,
        'MODEL': network_name,
        'SRCC': srcc_list,
        'KRCC': krcc_list,
        'PLCC': plcc_list,
        'RMSE': rmse_list,
        'SELECT_CRITERIC': select_criteria_list,
    })
36
+
37
def process_test_set(test_data_name, metadata_path, feature_path, network_name):
    """Load metadata and precomputed features for one test dataset.

    Reads `<metadata_path>/<NAME>_metadata.csv` and the matching feature
    tensor file; rescales the 5-point MOS of konvid_1k / youtube_ugc_h264 onto
    the common [1, 100] range used across datasets.

    Returns:
        (test_features, test_vids, test_scores, sorted_test_df)
    """
    test_df = pd.read_csv(f'{metadata_path}/{test_data_name.upper()}_metadata.csv')

    test_vids = test_df['vid']
    mos = torch.tensor(test_df['mos'].astype(float), dtype=torch.float32)
    if test_data_name in ('konvid_1k', 'youtube_ugc_h264'):
        # Map MOS from [1, 5] to [1, 100]: (mos - 1) * 99/4 + 1.
        test_scores = ((mos - 1) * (99 / 4) + 1.0)
    else:
        test_scores = mos

    sorted_test_df = pd.DataFrame({
        'vid': test_df['vid'],
        'framerate': test_df['framerate'],
        'MOS': test_scores,
        'MOS_raw': mos
    })
    test_features = torch.load(f'{feature_path}/{network_name}_{test_data_name}_features.pt')
    print(f'num of {test_data_name} features: {len(test_features)}')
    return test_features, test_vids, test_scores, sorted_test_df
56
+
57
def fix_state_dict(state_dict):
    """Normalize a checkpoint state dict saved from wrapped models.

    Strips the 'module.' prefix added by DataParallel/DistributedDataParallel
    and drops SWA's 'n_averaged' bookkeeping entry, which is not a parameter
    of the underlying model.
    """
    cleaned = {}
    for key, value in state_dict.items():
        if key == 'n_averaged':
            continue
        new_key = key[7:] if key.startswith('module.') else key
        cleaned[new_key] = value
    return cleaned
68
+
69
def collate_to_device(batch, device):
    """Collate (feature, target) pairs into stacked tensors on *device*."""
    features, labels = zip(*batch)
    stacked_features = torch.stack(features).to(device)
    stacked_labels = torch.stack(labels).to(device)
    return stacked_features, stacked_labels
72
+
73
def model_test(best_model, X, y, device):
    """Run *best_model* over (X, y) one sample at a time and collect predictions.

    Returns a flat Python list of float predictions (y is only used to build
    the dataset; labels are ignored during inference).
    """
    loader = DataLoader(dataset=TensorDataset(X, y), batch_size=1, shuffle=False)

    best_model.eval()
    predictions = []
    with torch.no_grad():
        for features, _ in loader:
            features = features.to(device)
            output = best_model(features)
            predictions.extend(output.view(-1).tolist())
    return predictions
86
+
87
def fine_tune_model(model, device, model_path, X_fine_tune, y_fine_tune, save_path, batch_size, epochs, loss_type, optimizer_type, initial_lr, weight_decay, use_swa, l1_w, rank_w):
    """Fine-tune a pretrained Mlp head on (X_fine_tune, y_fine_tune).

    Loads the checkpoint at *model_path* (tolerating partial mismatches),
    unfreezes all parameters, trains for *epochs*, optionally switching to
    Stochastic Weight Averaging for the final 25% of epochs, and returns the
    best model (lowest average training loss).

    Args:
        loss_type: 'MAERankLoss' for MAE+rank loss, anything else -> MSE.
        optimizer_type: 'sgd' (with cosine annealing) or AdamW (with StepLR).
        use_swa: enable SWA averaging after 75% of epochs.
        l1_w, rank_w: weights for the MAE and rank terms of MAEAndRankLoss.
    """
    state_dict = torch.load(model_path)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        # Best effort: report mismatched keys but continue with what loaded.
        print(e)

    # Unfreeze everything for full fine-tuning.
    for param in model.parameters():
        param.requires_grad = True
    model.train().to(device)  # to gpu

    fine_tune_dataset = TensorDataset(X_fine_tune, y_fine_tune)
    fine_tune_loader = DataLoader(dataset=fine_tune_dataset, batch_size=batch_size, shuffle=False)

    # initialisation of loss function, optimiser
    if loss_type == 'MAERankLoss':
        criterion = MAEAndRankLoss()
        criterion.l1_w = l1_w
        criterion.rank_w = rank_w
    else:
        criterion = nn.MSELoss()

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
    else:
        optimizer = optim.AdamW(model.parameters(), lr=initial_lr, weight_decay=weight_decay)  # L2 Regularisation initial: 0.01, 1e-5
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)  # step_size=10, gamma=0.1: every 10 epochs lr*0.1
    if use_swa:
        swa_model = AveragedModel(model).to(device)
        swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')
    swa_start = int(epochs * 0.75) if use_swa else epochs  # SWA starts after 75% of total epochs, only set SWA start if SWA is used

    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, labels in fine_tune_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # Targets reshaped to (batch, 1) to match the single-output head.
            loss = criterion(outputs, labels.view(-1, 1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)

        scheduler.step()
        if use_swa and epoch >= swa_start:
            # NOTE(review): both the base scheduler and the SWA scheduler step
            # in SWA epochs — confirm this double-stepping is intended.
            swa_model.update_parameters(model)
            swa_scheduler.step()
            print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")
        avg_loss = epoch_loss / len(fine_tune_loader.dataset)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

        # decide which model to evaluate: SWA model or regular model
        current_model = swa_model if use_swa and epoch >= swa_start else model
        # Save best model state
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_model = copy.deepcopy(current_model)

    # decide which model to evaluate: SWA model or regular model
    if use_swa and epoch >= swa_start:
        # Recompute BatchNorm statistics for the averaged weights.
        train_loader = DataLoader(dataset=fine_tune_dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_to_device(x, device))
        best_model = best_model.to(device)
        best_model.eval()
        torch.optim.swa_utils.update_bn(train_loader, best_model)
    # model_path_new = os.path.join(save_path, f"{test_data_name}_diva-vqa_fine_tuned_model.pth")
    # torch.save(best_model.state_dict(), model_path_new) # save finetuned model
    return best_model
159
+
160
def fine_tuned_model_test(model, device, X_test, y_test, test_data_name):
    """Evaluate a (fine-tuned) model on the held-out split.

    Rescales predictions/labels of konvid_1k / youtube_ugc_h264 back to their
    native 5-point MOS range before computing correlation metrics.

    Returns:
        (df_test_pred, y_test_convert, y_test_pred_logistic,
         plcc, rmse, srcc, krcc)
    """
    model.eval()
    raw_predictions = model_test(model, X_test, y_test, device)
    raw_predictions = torch.tensor(list(raw_predictions), dtype=torch.float32)

    if test_data_name in ('konvid_1k', 'youtube_ugc_h264'):
        # Inverse of the [1, 5] -> [1, 100] mapping applied at load time.
        y_true = (y_test - 1) / (99 / 4) + 1.0
        y_hat = (raw_predictions - 1) / (99 / 4) + 1.0
    else:
        y_true = y_test
        y_hat = raw_predictions

    y_hat_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_true.cpu().numpy(), y_hat.cpu().numpy())
    df_test_pred = pd.DataFrame({'MOS': y_true, 'y_test_pred': y_hat, 'y_test_pred_logistic': y_hat_logistic})
    return df_test_pred, y_true, y_hat_logistic, plcc_test, rmse_test, srcc_test, krcc_test
175
+
176
def wo_fine_tune_model(model, device, model_path, X_test, y_test, loss_type, test_data_name):
    """Evaluate the pretrained model on the full test set without fine-tuning.

    Loads the checkpoint (tolerating partial mismatches), reports the test
    loss, then computes predictions and correlation metrics on the MOS scale
    native to *test_data_name*.

    Returns:
        (df_test_pred, y_test_convert, y_test_pred_logistic,
         plcc, rmse, srcc, krcc)
    """
    state_dict = torch.load(model_path)
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        # Best effort: report mismatched keys but continue with what loaded.
        print(e)
    model.eval().to(device)  # to gpu

    if loss_type == 'MAERankLoss':
        criterion = MAEAndRankLoss()
    else:
        criterion = torch.nn.MSELoss()

    # evaluate the model
    # NOTE(review): this loss loop runs without torch.no_grad(), so autograd
    # graphs are built needlessly — confirm whether that is intentional.
    test_dataset = TensorDataset(X_test, y_test)
    test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
    test_loss = 0.0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels.view(-1, 1))
        test_loss += loss.item() * inputs.size(0)
    average_loss = test_loss / len(test_loader.dataset)
    print(f"Test Loss: {average_loss}")

    y_test_pred = model_test(model, X_test, y_test, device)
    y_test_pred = torch.tensor(list(y_test_pred), dtype=torch.float32)
    if test_data_name in ('konvid_1k', 'youtube_ugc_h264'):
        # Inverse of the [1, 5] -> [1, 100] mapping applied at load time.
        y_test_convert = ((y_test - 1) / (99 / 4) + 1.0)
        y_test_pred_convert = ((y_test_pred - 1) / (99 / 4) + 1.0)
    else:
        y_test_convert = y_test
        y_test_pred_convert = y_test_pred

    y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test_convert.cpu().numpy(), y_test_pred_convert.cpu().numpy())
    test_pred_score = {'MOS': y_test_convert, 'y_test_pred': y_test_pred_convert, 'y_test_pred_logistic': y_test_pred_logistic}
    df_test_pred = pd.DataFrame(test_pred_score)
    return df_test_pred, y_test_convert, y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test
215
+
216
def run(args):
    """Evaluate a pretrained quality-regression model on a test dataset, optionally
    fine-tuning it first, over repeated 80-20 hold-out splits, then report the
    median-performing run to CSV.

    Relies on module-level helpers (`process_test_set`, `preprocess_data`,
    `fine_tune_model`, `fine_tuned_model_test`, `wo_fine_tune_model`,
    `create_results_dataframe`, `Mlp`) and a module-level `device`.
    """
    data_list, srcc_list, krcc_list, plcc_list, rmse_list, select_criteria_list = [], [], [], [], [], []

    os.makedirs(os.path.join(args.report_path, 'fine_tune'), exist_ok=True)
    # Report filename encodes whether fine-tuning was applied.
    if args.is_finetune:
        csv_name = f'{args.report_path}/fine_tune/{args.test_data_name}_{args.network_name}_{args.select_criteria}_finetune.csv'
    else:
        csv_name = f'{args.report_path}/fine_tune/{args.test_data_name}_{args.network_name}_{args.select_criteria}_wo_finetune.csv'
    print(f'Test dataset: {args.test_data_name}')
    test_features, test_vids, test_scores, sorted_test_df = process_test_set(args.test_data_name, args.metadata_path, args.feature_path, args.network_name)
    X_test, y_test = preprocess_data(test_features, test_scores)

    # get save model param
    # Architecture must match the checkpoint saved during training.
    model = Mlp(input_features=X_test.shape[1], out_features=1, drop_rate=0.2, act_layer=nn.GELU)
    model = model.to(device)
    model_path = os.path.join(args.model_path, f"{args.train_data_name}_{args.network_name}_{args.model_name}_{args.select_criteria}_trained_model_kfold.pth")

    model_results = []
    for i in range(1, args.n_repeats + 1):
        print(f"{i}th repeated 80-20 hold out test")
        # Split seed varies with the repeat index for distinct hold-outs.
        X_fine_tune, X_final_test, y_fine_tune, y_final_test = train_test_split(X_test, y_test, test_size=0.2, random_state=math.ceil(8.8 * i))
        if args.is_finetune:
            # test fine tuned model on the test dataset
            ft_model = fine_tune_model(model, device, model_path, X_fine_tune, y_fine_tune, args.report_path, args.batch_size,
                                       args.epochs, args.loss_type, args.optimizer_type, args.initial_lr, args.weight_decay, args.use_swa, args.l1_w, args.rank_w)
            df_test_pred, y_test_convert, y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = fine_tuned_model_test(ft_model, device, X_final_test, y_final_test, args.test_data_name)
            best_model = copy.deepcopy(ft_model)
        else:
            # without fine tune on the test dataset
            # NOTE(review): this branch evaluates on the FULL X_test every repeat
            # (not X_final_test), so all repeats appear to produce identical
            # metrics — confirm whether that is intended.
            df_test_pred, y_test_convert, y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = wo_fine_tune_model(model, device, model_path, X_test, y_test, args.loss_type, args.test_data_name)
            print(y_test_pred_logistic)
            best_model = copy.deepcopy(model)

        model_results.append({
            'model': best_model,
            'srcc': srcc_test,
            'krcc': krcc_test,
            'plcc': plcc_test,
            'rmse': rmse_test,
            'df_pred': df_test_pred
        })
        print('\n')

    # Rank all repeats by the chosen criterion (RMSE ascending / KRCC descending)
    # and keep the median run as the representative result.
    if args.select_criteria == 'byrmse':
        sorted_results = sorted(model_results, key=lambda x: x['rmse'])
    elif args.select_criteria == 'bykrcc':
        sorted_results = sorted(model_results, key=lambda x: x['krcc'], reverse=True)
    else:
        raise ValueError(f"Unknown select_criteria: {args.select_criteria}")
    median_index = len(sorted_results) // 2
    median_result = sorted_results[median_index]
    median_model = median_result['model']
    median_df_test_pred = median_result['df_pred']
    median_srcc_test = median_result['srcc']
    median_krcc_test = median_result['krcc']
    median_plcc_test = median_result['plcc']
    median_rmse_test = median_result['rmse']
    data_list.append(args.test_data_name)
    srcc_list.append(median_srcc_test)
    krcc_list.append(median_krcc_test)
    plcc_list.append(median_plcc_test)
    rmse_list.append(median_rmse_test)
    select_criteria_list.append(args.select_criteria)
    median_df_test_pred.head()

    # save finetuned model
    if args.is_finetune:
        model_path_new = os.path.join(args.report_path, f"{args.test_data_name}_{args.network_name}_fine_tuned_model.pth")
        torch.save(median_model.state_dict(), model_path_new)
        print(f"Median model select {args.select_criteria} saved to {model_path_new}")

    df_results = create_results_dataframe(data_list, args.network_name, srcc_list, krcc_list, plcc_list, rmse_list, select_criteria_list)
    print(df_results.T)
    df_results.to_csv(csv_name, index=None, encoding="UTF-8")
290
+
291
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # input parameters
    parser.add_argument('--train_data_name', type=str, default='lsvq_train')
    parser.add_argument('--test_data_name', type=str, default='finevd')
    parser.add_argument('--network_name', type=str, default='camp-vqa')
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--select_criteria', type=str, default='byrmse', choices=['byrmse', 'bykrcc'])

    # paths
    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default=None)
    parser.add_argument('--model_path', type=str, default='../model/')
    parser.add_argument('--report_path', type=str, default='../log/')

    # training parameters
    parser.add_argument('--is_finetune', action='store_true', help="Enable fine-tuning")
    parser.add_argument('--n_repeats', type=int, default=21)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=200)

    # misc
    parser.add_argument('--loss_type', type=str, default='MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-2)
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    # NOTE(review): argparse `type=bool` is a known pitfall — any non-empty
    # string (including "False") parses as True. Prefer `action='store_true'`
    # or a str-to-bool converter; confirm callers never pass --use_swa False.
    parser.add_argument('--use_swa', type=bool, default=True, help="Enable SWA (default: True)")
    parser.add_argument('--l1_w', type=float, default=0.6)
    parser.add_argument('--rank_w', type=float, default=1.0)

    args = parser.parse_args()
    # Derive the default feature directory from the network name when not given.
    if args.feature_path is None:
        args.feature_path = f'../features/{args.network_name}/'
    print(f"[Paths] metadata: {args.metadata_path}; features: {args.feature_path}; model: {args.model_path}; report: {args.report_path}")
    run(args)
model_regression.py ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+ import math
7
+ import scipy.io
8
+ import scipy.stats
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.preprocessing import MinMaxScaler
11
+ from sklearn.metrics import mean_squared_error
12
+ from scipy.optimize import curve_fit
13
+ import joblib
14
+
15
+ import seaborn as sns
16
+ import matplotlib.pyplot as plt
17
+ import copy
18
+ import argparse
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ from torch.optim.lr_scheduler import CosineAnnealingLR
25
+ from torch.optim.swa_utils import AveragedModel, SWALR
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+ from sklearn.model_selection import train_test_split
28
+
29
+ from data_processing import split_train_test
30
+
31
+ # ignore all warnings
32
+ import warnings
33
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
34
+
35
+
36
class Mlp(nn.Module):
    """Three-layer perceptron regression head: D -> H -> H/2 -> 1.

    Attribute names (fc1/act1/drop1/fc2/act2/drop2/fc3) are part of the
    persisted state_dict layout and must not be renamed.
    """

    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        # self.bn1 = nn.BatchNorm1d(hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        """Map a (batch, input_features) tensor to (batch, out_features) scores."""
        hidden = self.drop1(self.act1(self.fc1(input_feature)))
        hidden = self.drop2(self.act2(self.fc2(hidden)))
        return self.fc3(hidden)
58
+
59
+
60
class MAEAndRankLoss(nn.Module):
    """Weighted sum of an L1 (MAE) term and a pairwise ranking term.

    The rank term penalises prediction pairs whose ordering disagrees with the
    ground-truth ordering; with `use_margin` enabled, true differences smaller
    than `margin` are ignored.

    Fix: guard the pairwise normalisation `n * (n - 1)` against n < 2, which
    previously produced a 0/0 -> NaN loss for single-sample batches.
    """

    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w      # weight of the MAE term
        self.rank_w = rank_w  # weight of the pairwise rank term
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # L1 loss/MAE loss
        l_mae = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w

        n = y_pred.size(0)
        if n < 2:
            # No pairs exist: the rank term is undefined (would be 0/0).
            return l_mae

        # Rank loss: all pairwise differences in one shot.
        pred_diff = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        true_diff = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # e(ytrue_i, ytrue_j): sign of the ground-truth ordering
        masks = torch.sign(true_diff)

        if self.use_margin and self.margin > 0:
            # Ignore pairs whose true difference is within the margin.
            true_diff = true_diff.abs() - self.margin
            true_diff = F.relu(true_diff)
            masks = true_diff.sign()

        l_rank = F.relu(true_diff - masks * pred_diff)
        l_rank = l_rank.sum() / (n * (n - 1))  # average over ordered pairs

        loss = l_mae + l_rank * self.rank_w
        return loss
89
+
90
def load_data(csv, data, data_name, set_name):
    """Read MOS labels from `csv` (column index 2, header row skipped) and pair
    them with the pre-extracted feature tensor `data`.

    Returns (X, y) where X is `data` unchanged and y is a float32 tensor of MOS.
    """
    try:
        frame = pd.read_csv(csv, skiprows=[], header=None)
    except Exception as e:
        logging.error(f'Read CSV file error: {e}')
        raise

    # Row 0 is the textual header (header=None keeps it), hence the [1:] slice.
    mos_values = frame.values[1:, 2].astype(float)
    y = torch.tensor(mos_values, dtype=torch.float32)

    if set_name == 'test':
        print(f"Modified y_true: {y}")
    X = data
    return X, y
104
+
105
def preprocess_data(X, y):
    """Sanitise and min-max scale a feature tensor, and flatten the label tensor.

    NaN/Inf entries in X are zeroed (in place), then each column is scaled to
    [0, 1]. Fix: a constant column previously produced a zero range and a
    divide-by-zero, reintroducing NaN/Inf right after they were scrubbed;
    such columns now scale to all zeros.

    Returns (X_scaled, y_flat).
    """
    X[torch.isnan(X)] = 0
    X[torch.isinf(X)] = 0

    # MinMaxScaler (use PyTorch implementation)
    X_min = X.min(dim=0, keepdim=True).values
    X_max = X.max(dim=0, keepdim=True).values
    denom = X_max - X_min
    denom[denom == 0] = 1  # constant columns: avoid 0/0, map to 0
    X = (X - X_min) / denom
    y = y.view(-1, 1).squeeze()
    return X, y
115
+
116
+ # define 4-parameter logistic regression
117
# define 4-parameter logistic regression
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Four-parameter logistic curve: bayta2 + (bayta1-bayta2)/(1+e^{-(X-b3)/|b4|})."""
    exponent = -(X - bayta3) / np.abs(bayta4)
    denominator = 1 + np.exp(exponent)
    return bayta2 + (bayta1 - bayta2) / denominator
121
+
122
def fit_logistic_regression(y_pred, y_true):
    """Fit the 4-parameter logistic curve mapping raw predictions onto MOS.

    Returns (y_pred_logistic, initial_guess, fitted_params).
    """
    # Standard VQA initial guess: [max MOS, min MOS, mean prediction, 0.5].
    initial_guess = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
    popt, _ = curve_fit(logistic_func, y_pred, y_true, p0=initial_guess, maxfev=100000000)
    return logistic_func(y_pred, *popt), initial_guess, popt
127
+
128
def compute_correlation_metrics(y_true, y_pred):
    """Compute the standard VQA evaluation metrics against ground-truth MOS.

    PLCC/RMSE are computed on logistic-mapped predictions; SRCC/KRCC on the
    raw predictions (rank correlations are invariant to the monotone mapping).

    Returns (y_pred_logistic, plcc, rmse, srcc, krcc).
    """
    y_pred_logistic, beta, popt = fit_logistic_regression(y_pred, y_true)

    plcc = scipy.stats.pearsonr(y_true, y_pred_logistic)[0]
    rmse = np.sqrt(mean_squared_error(y_true, y_pred_logistic))
    srcc = scipy.stats.spearmanr(y_true, y_pred)[0]

    try:
        krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
    except Exception as e:
        # Fall back to the asymptotic p-value method if the exact one fails.
        logging.error(f'krcc calculation: {e}')
        krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
    return y_pred_logistic, plcc, rmse, srcc, krcc
141
+
142
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
    """Scatter-plot predicted scores against MOS with the fitted logistic curve
    overlaid, and save the figure under ../figs/<data_name>/.

    Fixes: `mos1` was only bound when `y_test` was a torch.Tensor (NameError for
    numpy input); the bare `except` now preserves the original exception cause.
    """
    # nonlinear logistic fitted curve / logistic regression
    if isinstance(y_test, torch.Tensor):
        mos1 = y_test.numpy()
    else:
        mos1 = np.asarray(y_test)
    y1 = y_test_pred_logistic

    try:
        beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
        popt, pcov = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
        sigma = np.sqrt(np.diag(pcov))
    except Exception as err:
        raise Exception('Fitting logistic function time-out!!') from err
    x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
    plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')

    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
    plt.legend(loc='upper left')
    # Datasets rated on a 0-100 scale vs. a 1-5 ACR scale.
    if data_name == 'live_vqc' or data_name == 'live_qualcomm' or data_name == 'cvd_2014' or data_name == 'lsvq_train' or data_name == 'live_yt_gaming' or data_name == "finevd":
        plt.ylim(0, 100)
        plt.xlim(0, 100)
    else:
        plt.ylim(1, 5)
        plt.xlim(1, 5)
    plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
    plt.xlabel('Predicted Score')
    plt.ylabel('MOS')
    reg_fig1 = fig1.get_figure()

    fig_path = f'../figs/{data_name}/'
    os.makedirs(fig_path, exist_ok=True)
    reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_{select_criteria}.png", dpi=300)
    plt.clf()
    plt.close()
175
+
176
def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
    """Plot average train/validation loss curves and save the PNG under
    ../log/result/<data_name>/ for hold-out iteration `i`."""
    plt.figure(figsize=(10, 6))

    plt.plot(avg_train_losses, label='Average Training Loss')
    plt.plot(avg_val_losses, label='Average Validation Loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)
    plt.legend()

    out_dir = f'../log/result/{data_name}/'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
    plt.clf()
    plt.close()
192
+
193
def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
    """Initialise file-based DEBUG logging for one evaluation run and silence
    matplotlib's noisy DEBUG output."""
    log_file = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}.log")
    logging.basicConfig(filename=log_file, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
    logging.info(f"torch cuda: {torch.cuda.is_available()}")
199
+
200
def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
    """Load train/test MOS labels and features for `data_name`, then min-max
    normalise everything via `preprocess_data`.

    For 'lsvq_train' the caller supplies pre-split feature tensors
    (`train_features`/`test_features`); for other datasets features are read
    from pre-saved .pt splits and the two tensor arguments are ignored.
    Returns (X_train, y_train, X_test, y_test).
    """
    if data_name == 'lsvq_train':
        train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        X_train, y_train = load_data(train_csv, train_features, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_features, data_name, 'test')

    else:
        # NOTE(review): the CSV paths here are identical to the branch above;
        # only the feature source differs (torch.load of pre-split .pt files).
        train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
        test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')
        train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_train_features.pt')
        test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_test_features.pt')
        X_train, y_train = load_data(train_csv, train_data, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_data, data_name, 'test')

    # standard min-max normalization of training features
    X_train, y_train = preprocess_data(X_train, y_train)
    X_test, y_test = preprocess_data(X_test, y_test)

    return X_train, y_train, X_test, y_test
220
+
221
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over `train_loader`; return the mean
    per-sample training loss for the epoch."""
    model.train()
    running_total = 0.0
    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        optimizer.zero_grad()
        predictions = model(batch_inputs)
        # Targets arrive flat; reshape to (batch, 1) to match the model output.
        batch_loss = criterion(predictions, batch_targets.view(-1, 1))
        batch_loss.backward()
        optimizer.step()
        # Weight each batch by its size so the epoch average is per-sample.
        running_total += batch_loss.item() * batch_inputs.size(0)
    return running_total / len(train_loader.dataset)
236
+
237
def evaluate(model, val_loader, criterion, device):
    """Score `model` on a validation loader without gradient tracking.

    Returns (mean per-sample loss, concatenated predictions, concatenated targets).
    """
    model.eval()
    total_loss = 0.0
    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for batch_inputs, batch_targets in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            predictions = model(batch_inputs)
            pred_chunks.append(predictions)
            true_chunks.append(batch_targets)
            batch_loss = criterion(predictions, batch_targets.view(-1, 1))
            total_loss += batch_loss.item() * batch_inputs.size(0)

    mean_loss = total_loss / len(val_loader.dataset)
    return mean_loss, torch.cat(pred_chunks, dim=0), torch.cat(true_chunks, dim=0)
257
+
258
def update_best_model(select_criteria, best_metric, current_val, model):
    """Return (metric, model, improved) after comparing `current_val` against
    `best_metric` under the given criterion (RMSE: lower wins; KRCC: higher wins).

    On improvement the returned model is a deep copy of `model`.
    NOTE(review): when there is no improvement the *current* model (not any
    previously stored best) is returned — callers that assign the result to
    their best-model variable will overwrite it; confirm intended.
    """
    if select_criteria == 'byrmse':
        improved = current_val < best_metric
    elif select_criteria == 'bykrcc':
        improved = current_val > best_metric
    else:
        improved = False

    if improved:
        return current_val, copy.deepcopy(model), True
    return best_metric, model, False
268
+
269
def train_and_evaluate(X_train, y_train, config):
    """Train an Mlp regressor on an 80-20 train/validation split with optional
    SWA and early stopping; return the best model (by the configured criterion)
    plus padded per-epoch train/validation loss histories.

    Uses module-level `device`, `Mlp`, `MAEAndRankLoss`, `train_one_epoch`,
    `evaluate`, `update_best_model`. Returns (best_model, all_train_losses,
    all_val_losses), the loss lists each being a list containing one per-epoch
    sequence.
    """
    # parameters
    n_repeats = config['n_repeats']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    # Split data into train and validation (fixed seed: same split every call).
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
    best_model = None
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    # loss for every fold
    all_train_losses = []
    all_val_losses = []

    # initialisation of model, loss function, optimiser
    model = Mlp(input_features=X_train.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
    model = model.to(device)  # to gpu

    if loss_type == 'MAERankLoss':
        criterion = MAEAndRankLoss()
        criterion.l1_w = l1_w
        criterion.rank_w = rank_w
    else:
        criterion = nn.MSELoss()

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
    else:
        optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)  # L2 Regularisation initial: 0.01, 1e-5
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)  # step_size=10, gamma=0.1: every 10 epochs lr*0.1
    if use_swa:
        swa_model = AveragedModel(model).to(device)
        swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

    # dataset loader
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    train_losses, val_losses = [], []

    # early stopping parameters
    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop_active = False
    swa_start = int(epochs * 0.7) if use_swa else epochs  # SWA starts after 70% of total epochs, only set SWA start if SWA is used

    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        scheduler.step()  # update learning rate
        if use_swa and epoch >= swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
            # Early stopping only becomes active once the SWA phase begins.
            early_stop_active = True
            print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

        lr = optimizer.param_groups[0]['lr']
        print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

        # decide which model to evaluate: SWA model or regular model
        current_model = swa_model if use_swa and epoch >= swa_start else model
        current_model.eval()
        val_loss, y_val_pred, y_val_true = evaluate(current_model, val_loader, criterion, device)
        val_losses.append(val_loss)
        print(f"Epoch {epoch + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

        y_val_pred = torch.cat([pred for pred in y_val_pred])
        _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val.cpu().numpy(), y_val_pred.cpu().numpy())
        current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
        # NOTE(review): when the metric does NOT improve, update_best_model
        # returns the model passed in (current_model), so this assignment
        # replaces the stored best with the current epoch's model — confirm.
        best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
        if is_better:
            logging.info(f"Epoch {epoch + 1}:")
            y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val.cpu().numpy(), y_val_pred.cpu().numpy())
            logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

            X_train_fold_tensor = X_train
            y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().squeeze()
            y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train.cpu().numpy(), y_tra_pred_tmp.cpu().numpy())
            logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

        # check for loss improvement
        if early_stop_active:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # save the best model if validation loss improves
                # NOTE(review): copies the raw `model`, not the SWA model that
                # was just evaluated — confirm which one is intended here.
                best_model = copy.deepcopy(model)
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    # epochs to wait for improvement before stopping
                    print(f"Early stopping triggered after {epoch + 1} epochs.")
                    break

    # saving SWA models and updating BN statistics
    if use_swa:
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        best_model = best_model.to(device)
        best_model.eval()
        torch.optim.swa_utils.update_bn(train_loader, best_model)

    # Pad every loss history to the longest run (relevant with early stopping).
    all_train_losses.append(train_losses)
    all_val_losses.append(val_losses)
    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses
397
+
398
def collate_to_device(batch, device):
    """Collate a list of (sample, target) pairs into stacked tensors on `device`."""
    samples = torch.stack([item[0] for item in batch]).to(device)
    labels = torch.stack([item[1] for item in batch]).to(device)
    return samples, labels
401
+
402
def model_test(best_model, X, y, device):
    """Run `best_model` over (X, y) one sample at a time and return the raw
    predicted scores as a flat Python list (y is only used to build the dataset)."""
    loader = DataLoader(dataset=TensorDataset(X, y), batch_size=1, shuffle=False)

    best_model.eval()
    predictions = []
    with torch.no_grad():
        for batch_inputs, _ in loader:
            scores = best_model(batch_inputs.to(device))
            predictions.extend(scores.view(-1).tolist())

    return predictions
416
+
417
+ def main(config):
418
+ model_name = config['model_name']
419
+ data_name = config['data_name']
420
+ network_name = config['network_name']
421
+
422
+ metadata_path = config['metadata_path']
423
+ feature_path = config['feature_path']
424
+ log_path = config['log_path']
425
+ save_path = config['save_path']
426
+ score_path = config['score_path']
427
+ result_path = config['result_path']
428
+
429
+ # parameters
430
+ select_criteria = config['select_criteria']
431
+ n_repeats = config['n_repeats']
432
+
433
+ # logging and result
434
+ os.makedirs(log_path, exist_ok=True)
435
+ os.makedirs(save_path, exist_ok=True)
436
+ os.makedirs(score_path, exist_ok=True)
437
+ os.makedirs(result_path, exist_ok=True)
438
+ result_file = f'{result_path}{data_name}_{network_name}_{model_name}_{select_criteria}.mat'
439
+ pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{model_name}_Predicted_Score_{select_criteria}.csv")
440
+ file_path = os.path.join(save_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}_trained_model.pth")
441
+ configure_logging(log_path, model_name, data_name, network_name, select_criteria)
442
+
443
+ '''======================== Main Body ==========================='''
444
+ PLCC_all_repeats_test = []
445
+ SRCC_all_repeats_test = []
446
+ KRCC_all_repeats_test = []
447
+ RMSE_all_repeats_test = []
448
+ PLCC_all_repeats_train = []
449
+ SRCC_all_repeats_train = []
450
+ KRCC_all_repeats_train = []
451
+ RMSE_all_repeats_train = []
452
+ all_repeats_test_vids = []
453
+ all_repeats_df_test_pred = []
454
+ best_model_list = []
455
+
456
+ for i in range(1, n_repeats + 1):
457
+ print(f"{i}th repeated 80-20 hold out test")
458
+ logging.info(f"{i}th repeated 80-20 hold out test")
459
+ t0 = time.time()
460
+
461
+ # train test split
462
+ test_size = 0.2
463
+ random_state = math.ceil(8.8 * i)
464
+ # NR: original
465
+ if data_name == 'lsvq_train':
466
+ test_data_name = 'lsvq_test' #lsvq_test, lsvq_test_1080p
467
+ train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
468
+ else:
469
+ _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)
470
+
471
+ '''======================== read files =============================== '''
472
+ if data_name == 'lsvq_train':
473
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
474
+ else:
475
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)
476
+
477
+ '''======================== regression model =============================== '''
478
+ best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)
479
+
480
+ # average loss plots
481
+ avg_train_losses = np.mean(all_train_losses, axis=0)
482
+ avg_val_losses = np.mean(all_val_losses, axis=0)
483
+ test_vids = test_vids.tolist()
484
+ plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)
485
+
486
+ # predict best model on the train dataset
487
+ y_train_pred = model_test(best_model, X_train, y_train, device)
488
+ y_train_pred = torch.tensor(list(y_train_pred), dtype=torch.float32)
489
+ y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train.cpu().numpy(), y_train_pred.cpu().numpy())
490
+
491
+ # test best model on the test dataset
492
+ y_test_pred = model_test(best_model, X_test, y_test, device)
493
+ y_test_pred = torch.tensor(list(y_test_pred), dtype=torch.float32)
494
+ y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test.cpu().numpy(), y_test_pred.cpu().numpy())
495
+
496
+ # save the predict score results
497
+ test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
498
+ df_test_pred = pd.DataFrame(test_pred_score)
499
+
500
+ # logging logistic predicted scores
501
+ logging.info("============================================================================================================")
502
+ SRCC_all_repeats_test.append(srcc_test)
503
+ KRCC_all_repeats_test.append(krcc_test)
504
+ PLCC_all_repeats_test.append(plcc_test)
505
+ RMSE_all_repeats_test.append(rmse_test)
506
+ SRCC_all_repeats_train.append(srcc_train)
507
+ KRCC_all_repeats_train.append(krcc_train)
508
+ PLCC_all_repeats_train.append(plcc_train)
509
+ RMSE_all_repeats_train.append(rmse_train)
510
+ all_repeats_test_vids.append(test_vids)
511
+ all_repeats_df_test_pred.append(df_test_pred)
512
+ best_model_list.append(copy.deepcopy(best_model))
513
+
514
+ # logging.info results for each iteration
515
+ logging.info('Best results in Mlp model within one split')
516
+ logging.info(f'MODEL: {best_model}')
517
+ logging.info('======================================================')
518
+ logging.info(f'Train set - Evaluation Results')
519
+ logging.info(f'SRCC_train: {srcc_train}')
520
+ logging.info(f'KRCC_train: {krcc_train}')
521
+ logging.info(f'PLCC_train: {plcc_train}')
522
+ logging.info(f'RMSE_train: {rmse_train}')
523
+ logging.info('======================================================')
524
+ logging.info(f'Test set - Evaluation Results')
525
+ logging.info(f'SRCC_test: {srcc_test}')
526
+ logging.info(f'KRCC_test: {krcc_test}')
527
+ logging.info(f'PLCC_test: {plcc_test}')
528
+ logging.info(f'RMSE_test: {rmse_test}')
529
+ logging.info('======================================================')
530
+ logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))
531
+
532
+ logging.info('')
533
+ SRCC_all_repeats_test = torch.tensor(SRCC_all_repeats_test, dtype=torch.float32)
534
+ KRCC_all_repeats_test = torch.tensor(KRCC_all_repeats_test, dtype=torch.float32)
535
+ PLCC_all_repeats_test = torch.tensor(PLCC_all_repeats_test, dtype=torch.float32)
536
+ PLCC_all_repeats_test = PLCC_all_repeats_test[~torch.isnan(PLCC_all_repeats_test)]
537
+ RMSE_all_repeats_test = torch.tensor(RMSE_all_repeats_test, dtype=torch.float32)
538
+ SRCC_all_repeats_train = torch.tensor(SRCC_all_repeats_train, dtype=torch.float32)
539
+ KRCC_all_repeats_train = torch.tensor(KRCC_all_repeats_train, dtype=torch.float32)
540
+ PLCC_all_repeats_train = torch.tensor(PLCC_all_repeats_train, dtype=torch.float32)
541
+ RMSE_all_repeats_train = torch.tensor(RMSE_all_repeats_train, dtype=torch.float32)
542
+
543
+ logging.info('======================================================')
544
+ logging.info('Average training results among all repeated 80-20 holdouts:')
545
+ logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_train).item(), torch.std(SRCC_all_repeats_train).item())
546
+ logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_train).item(), torch.std(KRCC_all_repeats_train).item())
547
+ logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_train).item(), torch.std(PLCC_all_repeats_train).item())
548
+ logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_train).item(), torch.std(RMSE_all_repeats_train).item())
549
+ logging.info('======================================================')
550
+ logging.info('Average testing results among all repeated 80-20 holdouts:')
551
+ logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_test).item(), torch.std(SRCC_all_repeats_test).item())
552
+ logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_test).item(), torch.std(KRCC_all_repeats_test).item())
553
+ logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_test).item(), torch.std(PLCC_all_repeats_test).item())
554
+ logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_test).item(), torch.std(RMSE_all_repeats_test).item())
555
+ logging.info('======================================================')
556
+ logging.info('\n')
557
+
558
+ # find the median model and the index of the median
559
+ print('======================================================')
560
+ if select_criteria == 'byrmse':
561
+ median_metrics = torch.median(RMSE_all_repeats_test).item()
562
+ indices = (RMSE_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
563
+ select_criteria = select_criteria.replace('by', '').upper()
564
+ print(RMSE_all_repeats_test)
565
+ logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
566
+ elif select_criteria == 'bykrcc':
567
+ median_metrics = torch.median(KRCC_all_repeats_test).item()
568
+ indices = (KRCC_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
569
+ select_criteria = select_criteria.replace('by', '').upper()
570
+ print(KRCC_all_repeats_test)
571
+ logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')
572
+
573
+ median_test_vids = [all_repeats_test_vids[i] for i in indices]
574
+ test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])
575
+
576
+ # select the model with the first index where the median is located
577
+ # Note: If there are multiple iterations with the same median RMSE, the first index is selected here
578
+ median_model = None
579
+ if len(indices) > 0:
580
+ median_index = indices[0] # select the first index
581
+ median_model = best_model_list[median_index]
582
+ median_model_df_test_pred = all_repeats_df_test_pred[median_index]
583
+
584
+ median_model_df_test_pred.to_csv(pred_score_filename, index=False)
585
+ plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)
586
+
587
+ print(f'Median Metrics: {median_metrics}')
588
+ print(f'Indices: {indices}')
589
+ # print(f'Test Videos: {test_vids}')
590
+ print(f'Best model: {median_model}')
591
+
592
+ logging.info(f'median test {select_criteria}: {median_metrics}')
593
+ logging.info(f"Indices of median metrics: {indices}")
594
+ # logging.info(f'Best training and test dataset: {test_vids}')
595
+ logging.info(f'Best model predict score: {median_model_df_test_pred}')
596
+ logging.info(f'Best model: {median_model}')
597
+
598
+ # ================================================================================
599
+ # save mats
600
+ scipy.io.savemat(result_file, mdict={'SRCC_train': SRCC_all_repeats_train.numpy(),
601
+ 'KRCC_train': KRCC_all_repeats_train.numpy(),
602
+ 'PLCC_train': PLCC_all_repeats_train.numpy(),
603
+ 'RMSE_train': RMSE_all_repeats_train.numpy(),
604
+ 'SRCC_test': SRCC_all_repeats_test.numpy(),
605
+ 'KRCC_test': KRCC_all_repeats_test.numpy(),
606
+ 'PLCC_test': PLCC_all_repeats_test.numpy(),
607
+ 'RMSE_test': RMSE_all_repeats_test.numpy(),
608
+ f'Median_{select_criteria}': median_metrics,
609
+ 'Test_Videos_list': all_repeats_test_vids,
610
+ 'Test_videos_Median_model': test_vids})
611
+
612
+ # save model
613
+ torch.save(median_model.state_dict(), file_path)
614
+ print(f"Model state_dict saved to {file_path}")
615
+
616
+
617
if __name__ == '__main__':
    # Command-line entry point: parse hyper-parameters, pick the compute
    # device, then hand everything to main() as a plain dict.
    parser = argparse.ArgumentParser()
    # input parameters
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='konvid_1k')
    parser.add_argument('--network_name', type=str, default='camp-vqa')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/camp-vqa/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')
    # training parameters
    parser.add_argument('--select_criteria', type=str, default='byrmse')
    parser.add_argument('--n_repeats', type=int, default=21)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--hidden_features', type=int, default=256)
    parser.add_argument('--drop_rate', type=float, default=0.1)
    # misc
    parser.add_argument('--loss_type', type=str, default='MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-2)
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    parser.add_argument('--patience', type=int, default=5)
    # BUGFIX: the original used `type=bool`, but bool('False') is True, so any
    # explicit command-line value (including "False") enabled SWA.  Parse the
    # string explicitly; accepted truthy spellings: true/1/yes (case-insensitive).
    parser.add_argument('--use_swa',
                        type=lambda s: str(s).strip().lower() in ('true', '1', 'yes'),
                        default=True)
    parser.add_argument('--l1_w', type=float, default=0.6)
    parser.add_argument('--rank_w', type=float, default=1.0)

    args = parser.parse_args()
    config = vars(args)  # args to dict
    print(config)

    # `device` is read as a module-level global by the training helpers above.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)
model_regression_lsvq.py ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+ import os
4
+ import pandas as pd
5
+ import numpy as np
6
+ import math
7
+ import scipy.io
8
+ import scipy.stats
9
+ from sklearn.impute import SimpleImputer
10
+ from sklearn.preprocessing import MinMaxScaler
11
+ from sklearn.metrics import mean_squared_error
12
+ from scipy.optimize import curve_fit
13
+ import joblib
14
+
15
+ import seaborn as sns
16
+ import matplotlib.pyplot as plt
17
+ import copy
18
+ import argparse
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.optim as optim
24
+ from torch.optim.lr_scheduler import CosineAnnealingLR
25
+ from torch.optim.swa_utils import AveragedModel, SWALR
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+ from sklearn.model_selection import KFold
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from data_processing import split_train_test
31
+
32
+ # ignore all warnings
33
+ import warnings
34
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
35
+
36
+
37
class Mlp(nn.Module):
    """Small MLP regression head: two hidden layers with BatchNorm on the
    first, dropout after each activation, and a single scalar output.

    Layer attribute names and creation order are part of the public contract
    (checkpoint ``state_dict`` keys and parameter-init RNG order) and must
    not change.
    """

    def __init__(self, input_features, hidden_features=256, out_features=1, drop_rate=0.2, act_layer=nn.GELU):
        super().__init__()
        self.fc1 = nn.Linear(input_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.act1 = act_layer()
        self.drop1 = nn.Dropout(drop_rate)
        self.fc2 = nn.Linear(hidden_features, hidden_features // 2)
        self.act2 = act_layer()
        self.drop2 = nn.Dropout(drop_rate)
        self.fc3 = nn.Linear(hidden_features // 2, out_features)

    def forward(self, input_feature):
        """Map a (batch, input_features) tensor to a (batch, 1) score."""
        hidden = self.drop1(self.act1(self.bn1(self.fc1(input_feature))))
        hidden = self.drop2(self.act2(self.fc2(hidden)))
        return self.fc3(hidden)
59
+
60
+
61
class MAEAndRankLoss(nn.Module):
    """Weighted sum of an L1 (MAE) term and a pairwise rank-consistency term.

    The rank term penalises prediction pairs whose ordering margin falls
    short of the ground-truth ordering, averaged over the n*(n-1) ordered
    pairs in the batch (so a batch of size 1 is unsupported).
    """

    def __init__(self, l1_w=1.0, rank_w=1.0, margin=0.0, use_margin=False):
        super(MAEAndRankLoss, self).__init__()
        self.l1_w = l1_w      # weight of the MAE term
        self.rank_w = rank_w  # weight of the rank term
        self.margin = margin
        self.use_margin = use_margin

    def forward(self, y_pred, y_true):
        # L1 loss/MAE loss
        mae_term = F.l1_loss(y_pred, y_true, reduction='mean') * self.l1_w

        # Pairwise differences for the rank term
        num = y_pred.size(0)
        diff_pred = y_pred.unsqueeze(1) - y_pred.unsqueeze(0)
        diff_true = y_true.unsqueeze(1) - y_true.unsqueeze(0)

        # e(ytrue_i, ytrue_j): sign of the ground-truth ordering
        signs = torch.sign(diff_true)
        if self.use_margin and self.margin > 0:
            # only count pairs whose true gap exceeds the margin
            diff_true = F.relu(diff_true.abs() - self.margin)
            signs = diff_true.sign()

        rank_term = F.relu(diff_true - signs * diff_pred).sum() / (num * (num - 1))

        return mae_term + rank_term * self.rank_w
90
+
91
def load_data(csv, data, data_name, set_name):
    """Read MOS labels from column 2 of *csv* (first row treated as header
    and skipped); the feature blob *data* is passed through unchanged.

    Returns (data, labels) where labels is a 1-D float32 tensor.
    ``data_name`` is accepted for interface compatibility but unused here.
    """
    try:
        frame = pd.read_csv(csv, skiprows=[], header=None)
    except Exception as err:
        logging.error(f'Read CSV file error: {err}')
        raise

    scores = frame.values[1:, 2].astype(float)
    labels = torch.tensor(scores, dtype=torch.float32)

    if set_name == 'test':
        print(f"Modified y_true: {labels}")
    return data, labels
105
+
106
def preprocess_data(X, y):
    """Clean and column-wise min-max normalise a feature matrix.

    NaN/Inf entries in ``X`` are zeroed in place, then each column is scaled
    to [0, 1] (a PyTorch stand-in for sklearn's MinMaxScaler).

    BUGFIX: a constant (zero-range) column produced 0/0 = NaN, reintroducing
    the NaNs that were just scrubbed; the denominator is now clamped to the
    dtype's epsilon so constant columns map to 0 instead.

    Args:
        X: 2-D float tensor of features (mutated in place for NaN/Inf).
        y: 1-D label tensor, or None to skip label handling.

    Returns:
        Tuple of (normalised X, flattened y or None).
    """
    X[torch.isnan(X)] = 0
    X[torch.isinf(X)] = 0

    # MinMaxScaler (use PyTorch implementation)
    X_min = X.min(dim=0, keepdim=True).values
    X_max = X.max(dim=0, keepdim=True).values
    # guard against zero-range columns: avoid 0/0 -> NaN
    denom = (X_max - X_min).clamp(min=torch.finfo(X.dtype).eps)
    X = (X - X_min) / denom
    if y is not None:
        y = y.view(-1, 1).squeeze()
    return X, y
117
+
118
# define 4-parameter logistic regression
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """4-parameter logistic: asymptotes bayta1/bayta2, midpoint bayta3,
    slope scale |bayta4|.  Vectorised over numpy array input X."""
    exponent = -(X - bayta3) / np.abs(bayta4)
    return bayta2 + (bayta1 - bayta2) / (1 + np.exp(exponent))
123
+
124
def fit_logistic_regression(y_pred, y_true):
    """Fit the 4-parameter logistic mapping raw predictions onto MOS.

    Returns (mapped predictions, initial-guess parameters, fitted parameters).
    Note the second element is the *initial* guess, preserved for callers.
    """
    init = [np.max(y_true), np.min(y_true), np.mean(y_pred), 0.5]
    params, _ = curve_fit(logistic_func, y_pred, y_true, p0=init, maxfev=100000000)
    return logistic_func(y_pred, *params), init, params
129
+
130
def compute_correlation_metrics(y_true, y_pred):
    """Compute standard VQA metrics for raw predictions against MOS.

    PLCC and RMSE are computed on the logistic-mapped predictions;
    SRCC and KRCC are rank-based and use the raw predictions.

    Returns (y_pred_logistic, plcc, rmse, srcc, krcc).
    """
    mapped, _init, _params = fit_logistic_regression(y_pred, y_true)

    srcc = scipy.stats.spearmanr(y_true, y_pred)[0]
    plcc = scipy.stats.pearsonr(y_true, mapped)[0]
    rmse = np.sqrt(mean_squared_error(y_true, mapped))

    try:
        krcc = scipy.stats.kendalltau(y_true, y_pred)[0]
    except Exception as err:
        # fall back to the asymptotic estimator if the exact one fails
        logging.error(f'krcc calculation: {err}')
        krcc = scipy.stats.kendalltau(y_true, y_pred, method='asymptotic')[0]
    return mapped, plcc, rmse, srcc, krcc
143
+
144
def plot_results(y_test, y_test_pred_logistic, df_pred_score, model_name, data_name, network_name, select_criteria):
    """Scatter-plot MOS vs logistic-mapped predictions with the fitted curve,
    and save the figure under ../figs/<data_name>/.

    BUGFIX: ``mos1`` was only assigned when ``y_test`` was a torch.Tensor, so
    a plain array/list argument raised NameError; both input kinds are now
    accepted.  The bare ``except:`` is narrowed and chains the cause.
    """
    # nonlinear logistic fitted curve / logistic regression
    mos1 = y_test.numpy() if isinstance(y_test, torch.Tensor) else np.asarray(y_test)
    y1 = y_test_pred_logistic

    try:
        beta = [np.max(mos1), np.min(mos1), np.mean(y1), 0.5]
        popt, _ = curve_fit(logistic_func, y1, mos1, p0=beta, maxfev=100000000)
    except Exception as err:
        raise Exception('Fitting logistic function time-out!!') from err
    x_values1 = np.linspace(np.min(y1), np.max(y1), len(y1))
    plt.plot(x_values1, logistic_func(x_values1, *popt), '-', color='#c72e29', label='Fitted f(x)')

    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, markers='o', color='steelblue', label=network_name)
    plt.legend(loc='upper left')
    # these datasets score MOS on a 0-100 scale; the others use 1-5
    if data_name in ('live_vqc', 'live_qualcomm', 'cvd_2014', 'lsvq_train'):
        plt.ylim(0, 100)
        plt.xlim(0, 100)
    else:
        plt.ylim(1, 5)
        plt.xlim(1, 5)
    plt.title(f"Algorithm {network_name} with {model_name} on dataset {data_name}", fontsize=10)
    plt.xlabel('Predicted Score')
    plt.ylabel('MOS')
    reg_fig1 = fig1.get_figure()

    fig_path = f'../figs/{data_name}/'
    os.makedirs(fig_path, exist_ok=True)
    reg_fig1.savefig(fig_path + f"{network_name}_{model_name}_{data_name}_{select_criteria}_kfold.png", dpi=300)
    plt.clf()
    plt.close()
177
+
178
def plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, test_vids, i):
    """Plot the fold-averaged training/validation loss curves for repeat *i*
    and write the figure under ../log/result/<data_name>/."""
    plt.figure(figsize=(10, 6))

    plt.plot(avg_train_losses, label='Average Training Loss')
    plt.plot(avg_val_losses, label='Average Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Average Training and Validation Loss Across Folds - {network_name} with {model_name} (test_vids: {test_vids})', fontsize=10)
    plt.legend()

    out_dir = f'../log/result/{data_name}/'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f'{out_dir}/{network_name}_Average_Training_Loss_test{i}.png', dpi=50)
    plt.clf()
    plt.close()
194
+
195
def configure_logging(log_path, model_name, data_name, network_name, select_criteria):
    """Route the root logger to a per-run file (overwritten each run) and
    silence matplotlib's verbose DEBUG output."""
    log_file = os.path.join(log_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}_kfold.log")
    logging.basicConfig(filename=log_file, filemode='w', level=logging.DEBUG, format='%(levelname)s - %(message)s')
    logging.getLogger('matplotlib').setLevel(logging.WARNING)
    logging.info(f"Evaluating algorithm {network_name} with {model_name} on dataset {data_name}")
    logging.info(f"torch cuda: {torch.cuda.is_available()}")
201
+
202
def load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features):
    """Load train/test MOS labels and features, then min-max normalise.

    For 'lsvq_train' the caller supplies the feature tensors; for all other
    datasets they are loaded from the pre-split .pt files on disk.

    Improvement: the MOS CSV paths were built identically in both branches;
    the duplication is hoisted out of the conditional.

    Returns (X_train, y_train, X_test, y_test).
    """
    train_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_train.csv')
    test_csv = os.path.join(metadata_path, f'mos_files/{data_name}_MOS_test.csv')

    if data_name == 'lsvq_train':
        # features were computed by the caller (split_train_test.process_lsvq)
        X_train, y_train = load_data(train_csv, train_features, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_features, data_name, 'test')
    else:
        train_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_train_features.pt')
        test_data = torch.load(f'{feature_path}split_train_test/{network_name}_{data_name}_test_features.pt')
        X_train, y_train = load_data(train_csv, train_data, data_name, 'train')
        X_test, y_test = load_data(test_csv, test_data, data_name, 'test')

    # standard min-max normalization of training features
    X_train, y_train = preprocess_data(X_train, y_train)
    X_test, y_test = preprocess_data(X_test, y_test)

    return X_train, y_train, X_test, y_test
222
+
223
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over *train_loader*.

    Returns the sample-weighted mean training loss for the epoch.
    """
    model.train()
    running = 0.0
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y.view(-1, 1))
        loss.backward()
        optimizer.step()

        # weight by batch size so the epoch average is per-sample
        running += loss.item() * batch_x.size(0)
    return running / len(train_loader.dataset)
238
+
239
def evaluate(model, val_loader, criterion, device):
    """Evaluate *model* on *val_loader* without gradients.

    Returns (sample-weighted mean loss, concatenated predictions,
    concatenated targets).
    """
    model.eval()
    total = 0.0
    preds, trues = [], []
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            out = model(batch_x)
            preds.append(out)
            trues.append(batch_y)
            total += criterion(out, batch_y.view(-1, 1)).item() * batch_x.size(0)

    mean_loss = total / len(val_loader.dataset)
    return mean_loss, torch.cat(preds, dim=0), torch.cat(trues, dim=0)
259
+
260
def update_best_model(select_criteria, best_metric, current_val, model):
    """Return the updated (best_metric, best_model, improved) triple.

    'byrmse' treats lower as better, 'bykrcc' treats higher as better.
    On improvement the model is deep-copied so later training cannot
    mutate the stored best snapshot.
    """
    improved = ((select_criteria == 'byrmse' and current_val < best_metric)
                or (select_criteria == 'bykrcc' and current_val > best_metric))
    if improved:
        return current_val, copy.deepcopy(model), True
    return best_metric, model, False
270
+
271
def train_and_evaluate(X_train, y_train, config):
    """K-fold training loop for the Mlp regressor.

    Trains one fresh model per fold with optional SWA, tracks the single
    best model across ALL folds/epochs by the configured selection
    criterion, and records per-fold loss curves (padded to equal length).

    NOTE(review): relies on a module-level `device` defined at script
    entry; also uses sibling helpers Mlp, MAEAndRankLoss, train_one_epoch,
    evaluate, compute_correlation_metrics, update_best_model.

    Returns:
        (best_model, all_train_losses, all_val_losses) where the loss
        containers are lists (one per fold) of per-epoch floats.
    """
    # parameters
    n_repeats = config['n_repeats']
    n_splits = config['n_splits']
    batch_size = config['batch_size']
    epochs = config['epochs']
    hidden_features = config['hidden_features']
    drop_rate = config['drop_rate']
    loss_type = config['loss_type']
    optimizer_type = config['optimizer_type']
    select_criteria = config['select_criteria']
    initial_lr = config['initial_lr']
    weight_decay = config['weight_decay']
    patience = config['patience']
    l1_w = config['l1_w']
    rank_w = config['rank_w']
    use_swa = config.get('use_swa', False)
    logging.info(f'Parameters - Number of repeats for 80-20 hold out test: {n_repeats}, Number of splits for kfold: {n_splits}, Batch size: {batch_size}, Number of epochs: {epochs}')
    logging.info(f'Network Parameters - hidden_features: {hidden_features}, drop_rate: {drop_rate}, patience: {patience}')
    logging.info(f'Optimizer Parameters - loss_type: {loss_type}, optimizer_type: {optimizer_type}, initial_lr: {initial_lr}, weight_decay: {weight_decay}, use_swa: {use_swa}')
    logging.info(f'MAEAndRankLoss - l1_w: {l1_w}, rank_w: {rank_w}')

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    best_model = None
    # lower-is-better for RMSE, higher-is-better for KRCC
    best_metric = float('inf') if select_criteria == 'byrmse' else float('-inf')

    # loss for every fold
    all_train_losses = []
    all_val_losses = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(X_train)):
        print(f"Fold {fold + 1}/{n_splits}")

        X_train_fold, X_val_fold = X_train[train_idx], X_train[val_idx]
        y_train_fold, y_val_fold = y_train[train_idx], y_train[val_idx]

        # initialisation of model, loss function, optimiser (fresh per fold)
        model = Mlp(input_features=X_train_fold.shape[1], hidden_features=hidden_features, drop_rate=drop_rate)
        model = model.to(device)  # to gpu

        if loss_type == 'MAERankLoss':
            criterion = MAEAndRankLoss()
            criterion.l1_w = l1_w
            criterion.rank_w = rank_w
        else:
            criterion = nn.MSELoss()

        if optimizer_type == 'sgd':
            optimizer = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9, weight_decay=weight_decay)
            scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-5)  # initial eta_min=1e-5
        else:
            optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=weight_decay)  # L2 Regularisation initial: 0.01, 1e-5
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)  # step_size=10, gamma=0.1: every 10 epochs lr*0.1
        if use_swa:
            swa_model = AveragedModel(model).to(device)
            swa_scheduler = SWALR(optimizer, swa_lr=initial_lr, anneal_strategy='cos')

        # dataset loader
        train_dataset = TensorDataset(X_train_fold, y_train_fold)
        val_dataset = TensorDataset(X_val_fold, y_val_fold)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

        train_losses, val_losses = [], []

        # early stopping parameters (armed only once SWA averaging begins)
        best_val_loss = float('inf')
        epochs_no_improve = 0
        early_stop_active = False
        swa_start = int(epochs * 0.7) if use_swa else epochs  # SWA starts after 70% of total epochs, only set SWA start if SWA is used

        for epoch in range(epochs):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            train_losses.append(train_loss)
            scheduler.step()  # update learning rate
            if use_swa and epoch >= swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
                early_stop_active = True
                print(f"Current learning rate with SWA: {swa_scheduler.get_last_lr()}")

            lr = optimizer.param_groups[0]['lr']
            print('Epoch %d: Learning rate: %f' % (epoch + 1, lr))

            # decide which model to evaluate: SWA model or regular model
            current_model = swa_model if use_swa and epoch >= swa_start else model
            current_model.eval()
            val_loss, y_val_pred, y_val_true = evaluate(current_model, val_loader, criterion, device)
            val_losses.append(val_loss)
            print(f"Epoch {epoch + 1}, Fold {fold + 1}, Training Loss: {train_loss}, Validation Loss: {val_loss}")

            # flatten (n, 1) predictions to 1-D for the metric helpers
            y_val_pred = torch.cat([pred for pred in y_val_pred])
            _, _, rmse_val, _, krcc_val = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
            current_metric = rmse_val if select_criteria == 'byrmse' else krcc_val
            best_metric, best_model, is_better = update_best_model(select_criteria, best_metric, current_metric, current_model)
            if is_better:
                # log full metric panels for the newly-best snapshot
                logging.info(f"Epoch {epoch + 1}, Fold {fold + 1}:")
                y_val_pred_logistic_tmp, plcc_valid_tmp, rmse_valid_tmp, srcc_valid_tmp, krcc_valid_tmp = compute_correlation_metrics(y_val_fold.cpu().numpy(), y_val_pred.cpu().numpy())
                logging.info(f'Validation set - Evaluation Results - SRCC: {srcc_valid_tmp}, KRCC: {krcc_valid_tmp}, PLCC: {plcc_valid_tmp}, RMSE: {rmse_valid_tmp}')

                X_train_fold_tensor = X_train_fold
                y_tra_pred_tmp = best_model(X_train_fold_tensor).detach().cpu().squeeze()
                y_tra_pred_logistic_tmp, plcc_train_tmp, rmse_train_tmp, srcc_train_tmp, krcc_train_tmp = compute_correlation_metrics(y_train_fold.cpu().numpy(), y_tra_pred_tmp.cpu().numpy())
                logging.info(f'Train set - Evaluation Results - SRCC: {srcc_train_tmp}, KRCC: {krcc_train_tmp}, PLCC: {plcc_train_tmp}, RMSE: {rmse_train_tmp}')

            # check for loss improvement
            if early_stop_active:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # save the best model if validation loss improves
                    # NOTE(review): this overwrites the metric-selected best_model
                    # with a copy of the raw (non-SWA) model — confirm intended
                    best_model = copy.deepcopy(model)
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        # epochs to wait for improvement before stopping
                        print(f"Early stopping triggered after {epoch + 1} epochs.")
                        break

        # saving SWA models and updating BN statistics
        if use_swa:
            train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
            best_model = best_model.to(device)
            best_model.eval()
            torch.optim.swa_utils.update_bn(train_loader, best_model)
            # swa_model_path = os.path.join('save_swa_path='../model/', f'model_swa_fold{fold}.pth')
            # torch.save(swa_model.state_dict(), swa_model_path)
            # logging.info(f'SWA model saved at {swa_model_path}')

        all_train_losses.append(train_losses)
        all_val_losses.append(val_losses)

    # pad per-fold loss curves (early stopping makes them unequal length)
    max_length = max(len(x) for x in all_train_losses)
    all_train_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_train_losses]
    max_length = max(len(x) for x in all_val_losses)
    all_val_losses = [x + [x[-1]] * (max_length - len(x)) for x in all_val_losses]

    return best_model, all_train_losses, all_val_losses
407
+
408
def collate_to_device(batch, device):
    """Stack a list of (feature, target) pairs into two batched tensors
    and move both onto *device*."""
    features, targets = zip(*batch)
    return torch.stack(features).to(device), torch.stack(targets).to(device)
411
+
412
def model_test(best_model, X, y, device):
    """Run inference one sample at a time and return the predicted scores
    as a plain list of floats (order matches X; y is only used to build
    the dataset and is otherwise ignored)."""
    loader = DataLoader(dataset=TensorDataset(X, y), batch_size=1, shuffle=False)

    best_model.eval()
    predictions = []
    with torch.no_grad():
        for features, _ in loader:
            outputs = best_model(features.to(device))
            predictions.extend(outputs.view(-1).tolist())

    return predictions
426
+
427
+ def main(config):
428
+ model_name = config['model_name']
429
+ data_name = config['data_name']
430
+ network_name = config['network_name']
431
+
432
+ metadata_path = config['metadata_path']
433
+ feature_path = config['feature_path']
434
+ log_path = config['log_path']
435
+ save_path = config['save_path']
436
+ score_path = config['score_path']
437
+ result_path = config['result_path']
438
+
439
+ # parameters
440
+ select_criteria = config['select_criteria']
441
+ n_repeats = config['n_repeats']
442
+
443
+ # logging and result
444
+ os.makedirs(log_path, exist_ok=True)
445
+ os.makedirs(save_path, exist_ok=True)
446
+ os.makedirs(score_path, exist_ok=True)
447
+ os.makedirs(result_path, exist_ok=True)
448
+ result_file = f'{result_path}{data_name}_{network_name}_{model_name}_{select_criteria}_kfold.mat'
449
+ pred_score_filename = os.path.join(score_path, f"{data_name}_{network_name}_{model_name}_Predicted_Score_{select_criteria}_kfold.csv")
450
+ file_path = os.path.join(save_path, f"{data_name}_{network_name}_{model_name}_{select_criteria}_trained_model_kfold.pth")
451
+ configure_logging(log_path, model_name, data_name, network_name, select_criteria)
452
+
453
+ '''======================== Main Body ==========================='''
454
+ PLCC_all_repeats_test = []
455
+ SRCC_all_repeats_test = []
456
+ KRCC_all_repeats_test = []
457
+ RMSE_all_repeats_test = []
458
+ PLCC_all_repeats_train = []
459
+ SRCC_all_repeats_train = []
460
+ KRCC_all_repeats_train = []
461
+ RMSE_all_repeats_train = []
462
+ all_repeats_test_vids = []
463
+ all_repeats_df_test_pred = []
464
+ best_model_list = []
465
+
466
+ for i in range(1, n_repeats + 1):
467
+ print(f"{i}th repeated 80-20 hold out test")
468
+ logging.info(f"{i}th repeated 80-20 hold out test")
469
+ t0 = time.time()
470
+
471
+ # train test split
472
+ test_size = 0.2
473
+ random_state = math.ceil(8.8 * i)
474
+ # NR: original
475
+ if data_name == 'lsvq_train':
476
+ test_data_name = 'lsvq_test' #lsvq_test, lsvq_test_1080p
477
+ train_features, test_features, test_vids = split_train_test.process_lsvq(data_name, test_data_name, metadata_path, feature_path, network_name)
478
+ else:
479
+ _, _, test_vids = split_train_test.process_other(data_name, test_size, random_state, metadata_path, feature_path, network_name)
480
+
481
+ '''======================== read files =============================== '''
482
+ if data_name == 'lsvq_train':
483
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, train_features, test_features)
484
+ else:
485
+ X_train, y_train, X_test, y_test = load_and_preprocess_data(metadata_path, feature_path, data_name, network_name, None, None)
486
+
487
+ '''======================== regression model =============================== '''
488
+ best_model, all_train_losses, all_val_losses = train_and_evaluate(X_train, y_train, config)
489
+
490
+ # average loss plots
491
+ avg_train_losses = np.mean(all_train_losses, axis=0)
492
+ avg_val_losses = np.mean(all_val_losses, axis=0)
493
+ test_vids = test_vids.tolist()
494
+ plot_and_save_losses(avg_train_losses, avg_val_losses, model_name, data_name, network_name, len(test_vids), i)
495
+
496
+ # predict best model on the train dataset
497
+ y_train_pred = model_test(best_model, X_train, y_train, device)
498
+ y_train_pred = torch.tensor(list(y_train_pred), dtype=torch.float32)
499
+ y_train_pred_logistic, plcc_train, rmse_train, srcc_train, krcc_train = compute_correlation_metrics(y_train.cpu().numpy(), y_train_pred.cpu().numpy())
500
+
501
+ # test best model on the test dataset
502
+ y_test_pred = model_test(best_model, X_test, y_test, device)
503
+ y_test_pred = torch.tensor(list(y_test_pred), dtype=torch.float32)
504
+ y_test_pred_logistic, plcc_test, rmse_test, srcc_test, krcc_test = compute_correlation_metrics(y_test.cpu().numpy(), y_test_pred.cpu().numpy())
505
+
506
+ # save the predict score results
507
+ test_pred_score = {'MOS': y_test, 'y_test_pred': y_test_pred, 'y_test_pred_logistic': y_test_pred_logistic}
508
+ df_test_pred = pd.DataFrame(test_pred_score)
509
+
510
+ # logging logistic predicted scores
511
+ logging.info("============================================================================================================")
512
+ SRCC_all_repeats_test.append(srcc_test)
513
+ KRCC_all_repeats_test.append(krcc_test)
514
+ PLCC_all_repeats_test.append(plcc_test)
515
+ RMSE_all_repeats_test.append(rmse_test)
516
+ SRCC_all_repeats_train.append(srcc_train)
517
+ KRCC_all_repeats_train.append(krcc_train)
518
+ PLCC_all_repeats_train.append(plcc_train)
519
+ RMSE_all_repeats_train.append(rmse_train)
520
+ all_repeats_test_vids.append(test_vids)
521
+ all_repeats_df_test_pred.append(df_test_pred)
522
+ best_model_list.append(copy.deepcopy(best_model))
523
+
524
+ # logging.info results for each iteration
525
+ logging.info('Best results in Mlp model within one split')
526
+ logging.info(f'MODEL: {best_model}')
527
+ logging.info('======================================================')
528
+ logging.info(f'Train set - Evaluation Results')
529
+ logging.info(f'SRCC_train: {srcc_train}')
530
+ logging.info(f'KRCC_train: {krcc_train}')
531
+ logging.info(f'PLCC_train: {plcc_train}')
532
+ logging.info(f'RMSE_train: {rmse_train}')
533
+ logging.info('======================================================')
534
+ logging.info(f'Test set - Evaluation Results')
535
+ logging.info(f'SRCC_test: {srcc_test}')
536
+ logging.info(f'KRCC_test: {krcc_test}')
537
+ logging.info(f'PLCC_test: {plcc_test}')
538
+ logging.info(f'RMSE_test: {rmse_test}')
539
+ logging.info('======================================================')
540
+ logging.info(' -- {} seconds elapsed...\n\n'.format(time.time() - t0))
541
+
542
+ logging.info('')
543
+ SRCC_all_repeats_test = torch.tensor(SRCC_all_repeats_test, dtype=torch.float32)
544
+ KRCC_all_repeats_test = torch.tensor(KRCC_all_repeats_test, dtype=torch.float32)
545
+ PLCC_all_repeats_test = torch.tensor(PLCC_all_repeats_test, dtype=torch.float32)
546
+ RMSE_all_repeats_test = torch.tensor(RMSE_all_repeats_test, dtype=torch.float32)
547
+ SRCC_all_repeats_train = torch.tensor(SRCC_all_repeats_train, dtype=torch.float32)
548
+ KRCC_all_repeats_train = torch.tensor(KRCC_all_repeats_train, dtype=torch.float32)
549
+ PLCC_all_repeats_train = torch.tensor(PLCC_all_repeats_train, dtype=torch.float32)
550
+ RMSE_all_repeats_train = torch.tensor(RMSE_all_repeats_train, dtype=torch.float32)
551
+
552
+ logging.info('======================================================')
553
+ logging.info('Average training results among all repeated 80-20 holdouts:')
554
+ logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_train).item(), torch.std(SRCC_all_repeats_train).item())
555
+ logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_train).item(), torch.std(KRCC_all_repeats_train).item())
556
+ logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_train).item(), torch.std(PLCC_all_repeats_train).item())
557
+ logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_train).item(), torch.std(RMSE_all_repeats_train).item())
558
+ logging.info('======================================================')
559
+ logging.info('Average testing results among all repeated 80-20 holdouts:')
560
+ logging.info('SRCC: %f (std: %f)', torch.median(SRCC_all_repeats_test).item(), torch.std(SRCC_all_repeats_test).item())
561
+ logging.info('KRCC: %f (std: %f)', torch.median(KRCC_all_repeats_test).item(), torch.std(KRCC_all_repeats_test).item())
562
+ logging.info('PLCC: %f (std: %f)', torch.median(PLCC_all_repeats_test).item(), torch.std(PLCC_all_repeats_test).item())
563
+ logging.info('RMSE: %f (std: %f)', torch.median(RMSE_all_repeats_test).item(), torch.std(RMSE_all_repeats_test).item())
564
+ logging.info('======================================================')
565
+ logging.info('\n')
566
+
567
+ # find the median model and the index of the median
568
+ print('======================================================')
569
+ if select_criteria == 'byrmse':
570
+ median_metrics = torch.median(RMSE_all_repeats_test).item()
571
+ indices = (RMSE_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
572
+ select_criteria = select_criteria.replace('by', '').upper()
573
+ print(RMSE_all_repeats_test)
574
+ logging.info(f'all {select_criteria}: {RMSE_all_repeats_test}')
575
+ elif select_criteria == 'bykrcc':
576
+ median_metrics = torch.median(KRCC_all_repeats_test).item()
577
+ indices = (KRCC_all_repeats_test == median_metrics).nonzero(as_tuple=True)[0].tolist()
578
+ select_criteria = select_criteria.replace('by', '').upper()
579
+ print(KRCC_all_repeats_test)
580
+ logging.info(f'all {select_criteria}: {KRCC_all_repeats_test}')
581
+
582
+ median_test_vids = [all_repeats_test_vids[i] for i in indices]
583
+ test_vids = [arr.tolist() for arr in median_test_vids] if len(median_test_vids) > 1 else (median_test_vids[0] if median_test_vids else [])
584
+
585
+ # select the model with the first index where the median is located
586
+ # Note: If there are multiple iterations with the same median RMSE, the first index is selected here
587
+ median_model = None
588
+ if len(indices) > 0:
589
+ median_index = indices[0] # select the first index
590
+ median_model = best_model_list[median_index]
591
+ median_model_df_test_pred = all_repeats_df_test_pred[median_index]
592
+
593
+ median_model_df_test_pred.to_csv(pred_score_filename, index=False)
594
+ plot_results(y_test, y_test_pred_logistic, median_model_df_test_pred, model_name, data_name, network_name, select_criteria)
595
+
596
+ print(f'Median Metrics: {median_metrics}')
597
+ print(f'Indices: {indices}')
598
+ # print(f'Test Videos: {test_vids}')
599
+ print(f'Best model: {median_model}')
600
+
601
+ logging.info(f'median test {select_criteria}: {median_metrics}')
602
+ logging.info(f"Indices of median metrics: {indices}")
603
+ # logging.info(f'Best training and test dataset: {test_vids}')
604
+ logging.info(f'Best model predict score: {median_model_df_test_pred}')
605
+ logging.info(f'Best model: {median_model}')
606
+
607
+ # ================================================================================
608
+ # save mats
609
+ scipy.io.savemat(result_file, mdict={'SRCC_train': SRCC_all_repeats_train.numpy(),
610
+ 'KRCC_train': KRCC_all_repeats_train.numpy(),
611
+ 'PLCC_train': PLCC_all_repeats_train.numpy(),
612
+ 'RMSE_train': RMSE_all_repeats_train.numpy(),
613
+ 'SRCC_test': SRCC_all_repeats_test.numpy(),
614
+ 'KRCC_test': KRCC_all_repeats_test.numpy(),
615
+ 'PLCC_test': PLCC_all_repeats_test.numpy(),
616
+ 'RMSE_test': RMSE_all_repeats_test.numpy(),
617
+ f'Median_{select_criteria}': median_metrics,
618
+ 'Test_Videos_list': all_repeats_test_vids,
619
+ 'Test_videos_Median_model': test_vids})
620
+
621
+ # save model
622
+ torch.save(median_model.state_dict(), file_path)
623
+ print(f"Model state_dict saved to {file_path}")
624
+
625
+
626
if __name__ == '__main__':

    def _str2bool(value):
        """Parse a command-line truthy string into a real bool.

        argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
        True because any non-empty string is truthy, so ``--use_swa False``
        would silently enable SWA. This converter accepts the usual spellings
        and keeps the original ``--use_swa True/False`` flag syntax working.
        """
        return str(value).strip().lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    # input parameters
    parser.add_argument('--model_name', type=str, default='Mlp')
    parser.add_argument('--data_name', type=str, default='lsvq_train')
    parser.add_argument('--network_name', type=str, default='camp-vqa')

    parser.add_argument('--metadata_path', type=str, default='../metadata/')
    parser.add_argument('--feature_path', type=str, default='../features/camp-vqa/')
    parser.add_argument('--log_path', type=str, default='../log/')
    parser.add_argument('--save_path', type=str, default='../model/')
    parser.add_argument('--score_path', type=str, default='../log/predict_score/')
    parser.add_argument('--result_path', type=str, default='../log/result/')
    # training parameters
    parser.add_argument('--select_criteria', type=str, default='byrmse')
    parser.add_argument('--n_repeats', type=int, default=21)
    parser.add_argument('--n_splits', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--hidden_features', type=int, default=256)
    parser.add_argument('--drop_rate', type=float, default=0.1)
    # misc
    parser.add_argument('--loss_type', type=str, default='MAERankLoss')
    parser.add_argument('--optimizer_type', type=str, default='sgd')
    parser.add_argument('--initial_lr', type=float, default=1e-1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--patience', type=int, default=5)
    # NOTE: was `type=bool`, which made `--use_swa False` evaluate to True;
    # _str2bool preserves the same CLI shape but parses the value correctly.
    parser.add_argument('--use_swa', type=_str2bool, default=True)
    parser.add_argument('--l1_w', type=float, default=0.6)
    parser.add_argument('--rank_w', type=float, default=1.0)

    args = parser.parse_args()
    config = vars(args)  # args namespace -> plain dict, as main() expects
    print(config)

    # Prefer GPU when available; pin to device 0 for reproducibility.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)
    if device.type == "cuda":
        torch.cuda.set_device(0)

    main(config)
requirements.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.6.0
2
+ av==14.3.0
3
+ certifi==2025.1.31
4
+ charset-normalizer==3.4.1
5
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
6
+ ffmpeg==1.4
7
+ filelock==3.18.0
8
+ fsspec==2025.3.2
9
+ ftfy==6.3.1
10
+ fvcore==0.1.5.post20221221
11
+ huggingface-hub==0.30.1
12
+ idna==3.10
13
+ iopath==0.1.10
14
+ Jinja2==3.1.6
15
+ joblib==1.4.2
16
+ MarkupSafe==3.0.2
17
+ matplotlib==3.10.0
18
+ mpmath==1.3.0
19
+ networkx==3.4.2
20
+ nvidia-cublas-cu12==12.4.5.8
21
+ nvidia-cuda-cupti-cu12==12.4.127
22
+ nvidia-cuda-nvrtc-cu12==12.4.127
23
+ nvidia-cuda-runtime-cu12==12.4.127
24
+ nvidia-cudnn-cu12==9.1.0.70
25
+ nvidia-cufft-cu12==11.2.1.3
26
+ nvidia-curand-cu12==10.3.5.147
27
+ nvidia-cusolver-cu12==11.6.1.9
28
+ nvidia-cusparse-cu12==12.3.1.170
29
+ nvidia-cusparselt-cu12==0.6.2
30
+ nvidia-nccl-cu12==2.21.5
31
+ nvidia-nvjitlink-cu12==12.4.127
32
+ nvidia-nvtx-cu12==12.4.127
33
+ opencv-python==4.11.0.86
34
+ parameterized==0.9.0
35
+ peft==0.15.2
36
+ portalocker==3.1.1
37
+ psutil==7.0.0
38
+ PyQt6==6.7.1
39
+ pytorchvideo==0.1.5
40
+ PyYAML==6.0.2
41
+ regex==2024.11.6
42
+ requests==2.32.3
43
+ safetensors==0.5.3
44
+ scikit-learn==1.6.1
45
+ scipy==1.15.2
46
+ seaborn==0.13.2
47
+ sentencepiece==0.2.0
48
+ sympy==1.13.1
49
+ tabulate==0.9.0
50
+ termcolor==3.1.0
51
+ threadpoolctl==3.6.0
52
+ timm==1.0.15
53
+ tokenizers==0.21.1
54
+ torch==2.6.0
55
+ torchvision==0.21.0
56
+ tqdm==4.67.1
57
+ transformers==4.50.3
58
+ triton==3.2.0
59
+ typing_extensions==4.13.0
60
+ urllib3==2.3.0
61
+ wcwidth==0.2.13
62
+ yacs==0.1.8