# CAMP-VQA / app.py
from spaces import GPU
import gradio as gr
import torch
import os
import pandas as pd
from types import SimpleNamespace
import clip
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from extractor.extract_frag import VideoDataset_feature
from extractor.extract_slowfast_clip import SlowFast
from extractor.extract_swint_clip import SwinT
from demo_test import get_transform, load_prompts, get_video_metadata, load_model, evaluate_video_quality
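# Module-level cache so the heavy feature extractors and vision-language models are loaded only once per process.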
model_cache = {}
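# The spaces.GPU decorator requests a ZeroGPU device for the duration of each call on Hugging Face Spaces.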
@GPU
def run_camp_vqa(video_path, intra_cross_experiment, is_finetune, train_data_name, test_data_name, network_name):
if not video_path or not os.path.exists(video_path):
return "❌ No video uploaded or the uploaded file has expired. Please upload again."
print("CUDA available:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
print("Current device:", torch.cuda.current_device())
else:
print("Running on CPU")
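# Inference configuration for this run: paths, dataset names, and the resize/patch settings used by the extractors.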
config = SimpleNamespace(**{
'model_name': 'Mlp',
'select_criteria': 'byrmse',
'intra_cross_experiment': intra_cross_experiment,
'is_finetune': is_finetune,
'save_model_path': 'model/',
'prompt_path': './config/prompts.json',
'train_data_name': train_data_name,
'test_data_name': test_data_name,
'test_video_path': video_path,
'prediction_mode': 50,
'network_name': network_name,
'num_workers': 2,
'resize': 224,
'patch_size': 16,
'target_size': 224,
})
print(f"Test video path: {config.test_video_path}")
# Prepare the uploaded video as a single test sample
resize_transform = get_transform(config.resize)
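# Patch grid size derived from the layout above: (target_size / patch_size)^2 = (224 / 16)^2 = 196.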
top_n = int(config.target_size / config.patch_size) ** 2
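# Read basic stream metadata (resolution, bitrate, bit depth, frame rate) from the uploaded file.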
width, height, bitrate, bitdepth, framerate = get_video_metadata(config.test_video_path)
data = {'vid': [os.path.splitext(os.path.basename(config.test_video_path))[0]],
'test_data_name': [config.test_data_name],
'test_video_path': [config.test_video_path],
'prediction_mode': [config.prediction_mode],
'width': [width], 'height': [height], 'bitrate': [bitrate], 'bitdepth': [bitdepth], 'framerate': [framerate]}
videos_dir = os.path.dirname(config.test_video_path)
test_df = pd.DataFrame(data)
print(test_df.T)
print(f"Experiment Setting: {config.intra_cross_experiment}, {config.train_data_name} -> {config.test_data_name}")
if config.intra_cross_experiment == 'cross':
if config.train_data_name == 'lsvq_train':
print(f"Fine-tune: {config.is_finetune}")
# Load models onto the device, reusing the global cache when it is already populated
global model_cache
if not model_cache:
print("Loading models into cache (first time)...")
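# First call in this process: load SlowFast and Swin-T (spatio-temporal features), CLIP ViT-B/32, and BLIP-2 Flan-T5-XL, then keep them in the module-level cache.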
model_cache["slowfast"] = SlowFast().to(device)
model_cache["swint"] = SwinT(model_name='swin_large_patch4_window7_224', global_pool='avg', pretrained=True).to(device)
model_cache["clip"], model_cache["clip_preprocess"] = clip.load("ViT-B/32", device=device)
model_cache["blip_processor"] = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl", use_fast=True)
model_cache["blip_model"] = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
print("Model cache initialized.")
# Get the models from the cache
model_slowfast = model_cache["slowfast"]
model_swint = model_cache["swint"]
clip_model, clip_preprocess = model_cache["clip"], model_cache["clip_preprocess"]
blip_processor = model_cache["blip_processor"]
blip_model = model_cache["blip_model"]
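# Dimensionality of the concatenated feature vector expected by the MLP quality regressor.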
input_features = 13056
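# Pick the MLP/preprocessing variant: the LSVQ version for cross-dataset runs (and intra runs trained on LSVQ), otherwise the per-dataset version.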
if config.intra_cross_experiment == 'intra':
if config.train_data_name == 'lsvq_train':
from model_regression_lsvq import Mlp, preprocess_data
else:
from model_regression import Mlp, preprocess_data
elif config.intra_cross_experiment == 'cross':
from model_regression_lsvq import Mlp, preprocess_data
model_mlp = load_model(config, device, Mlp, input_features)
prompts = load_prompts(config.prompt_path)
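# Single-video dataset that prepares resized frames and patch samples (patch_size, top_n) for the feature extractors.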
dataset = VideoDataset_feature(test_df, videos_dir, config.test_data_name, resize_transform, config.resize, config.patch_size, config.target_size, top_n)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=1, shuffle=False, num_workers=min(config.num_workers, os.cpu_count() or 1), pin_memory=(device.type == "cuda")
)
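# Score the video: evaluate_video_quality combines the visual backbones, CLIP/BLIP-2 with the loaded prompts, and the MLP regressor (see demo_test.py).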
try:
score = evaluate_video_quality(
preprocess_data,
data_loader,
model_slowfast,
model_swint,
clip_model,
clip_preprocess,
blip_processor,
blip_model,
prompts,
model_mlp,
device
)
return f"**Predicted Perceptual Quality Score:** {score:.4f} / 100"
except Exception as e:
return f"❌ Error: {str(e)}"
finally:
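# Clean up the temporary file Gradio created for the upload once scoring has finished.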
if "gradio" in video_path and os.path.exists(video_path):
os.remove(video_path)
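# Show/hide helper driven by the fine-tune checkbox; note it is not attached to any event in the UI below.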
def toggle_dataset_visibility(is_finetune):
return gr.update(visible=is_finetune)
with gr.Blocks() as demo:
gr.Markdown("# 📹 CAMP-VQA Online Demo")
gr.Markdown(
"Upload a short video and get its perceptual quality score predicted by CAMP-VQA. "
"You can try our "
"<a href='https://huggingface.co/spaces/xinyiW915/CAMP-VQA/blob/main/ugc_original_videos/0_16_07_500001604801190-yase.mp4' target='_blank'>demo video</a>. "
"<br><br>"
# "⚙️ This demo is currently running on <strong>Hugging Face CPU Basic</strong>: 2 vCPU • 16 GB RAM."
"⚙️ This demo is currently running on <strong>Hugging Face ZeroGPU Space</strong>: Dynamic resources (NVIDIA A100)."
)
with gr.Row():
with gr.Column(scale=2):
video_input = gr.Video(label="Upload a Video (e.g. .mp4)")
intra_cross_experiment = gr.Dropdown(
label="Intra or Cross experiment",
choices=["intra", "cross"],
value="cross"
)
is_finetune_checkbox = gr.Checkbox(label="Use Finetuning?", value=False)
train_dataset = gr.Dropdown(
label="Train Dataset",
choices=["lsvq_train", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
value="lsvq_train"
)
test_dataset = gr.Dropdown(
label="Test Dataset for Finetuning",
choices=["lsvq_test", "lsvq_test_1080p", "cvd_2014", "konvid_1k", "live_vqc", "youtube_ugc", "finevd", "live_yt_gaming", "kvq"],
value="finevd"
)
model_dropdown = gr.Dropdown(
label="Our Models",
choices=["camp-vqa"],
value="camp-vqa"
)
run_button = gr.Button("Run Prediction")
with gr.Column(scale=1):
output_box = gr.Textbox(label="Predicted Quality Score (0–100)", lines=5)
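# toggle_dataset_visibility (defined above) is currently unused; if the test-dataset dropdown
# should only be visible when fine-tuning is enabled, it could be wired up like this
# (a sketch, not part of the original app):
# is_finetune_checkbox.change(fn=toggle_dataset_visibility, inputs=is_finetune_checkbox, outputs=test_dataset)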
run_button.click(
fn=run_camp_vqa,
inputs=[video_input, intra_cross_experiment, is_finetune_checkbox, train_dataset, test_dataset, model_dropdown],
outputs=output_box,
api_name="run",
queue=True
)
demo.launch()