"""
Improved quick test script for a Sanskrit multimodal model.

Uses better prompting to get an actual Sanskrit text transcription.
"""

import json
import torch
import base64
import io
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from peft import PeftModel
import numpy as np
import re
import os


def load_model_and_processor(model_path: str, adapter_path: str = None):
    """Load the base model and processor, optionally applying LoRA adapters."""
    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(model_path)

    print("Loading base model...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map={"": 0},
    )

    if adapter_path and os.path.exists(adapter_path):
        print("Loading LoRA adapters...")
        model = PeftModel.from_pretrained(model, adapter_path)
    else:
        print("No adapter path found, using base model only")

    model.eval()
    print(f"Model loaded on device: {next(model.parameters()).device}")
    return model, processor
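
# Note (optional, not part of the original flow): for faster generation the LoRA
# weights could be folded into the base model via PeftModel.merge_and_unload();
# keeping the adapters separate, as above, is fine for a quick test.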


def decode_base64_image(base64_string: str) -> Image.Image:
    """Decode base64 string to PIL Image"""
    if base64_string.startswith('data:image'):
        base64_string = base64_string.split(',')[1]

    image_data = base64.b64decode(base64_string)
    image = Image.open(io.BytesIO(image_data))
    return image


def preprocess_sanskrit_text(text: str) -> str:
    """Preprocess Sanskrit text for evaluation"""
    text = re.sub(r'\s+', ' ', text.strip())
    return text


def calculate_exact_match(predicted: str, ground_truth: str) -> bool:
    """Return True if the normalized prediction matches the ground truth exactly."""
    predicted = preprocess_sanskrit_text(predicted)
    ground_truth = preprocess_sanskrit_text(ground_truth)
    return predicted == ground_truth


def calculate_character_accuracy(predicted: str, ground_truth: str) -> float:
    """Calculate character-level accuracy using edit distance."""
    predicted = preprocess_sanskrit_text(predicted)
    ground_truth = preprocess_sanskrit_text(ground_truth)

    if not ground_truth:
        return 1.0 if not predicted else 0.0

    # Levenshtein distance via dynamic programming
    m, n = len(predicted), len(ground_truth)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if predicted[i - 1] == ground_truth[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])

    edit_distance = dp[m][n]
    max_length = max(m, n)
    accuracy = 1.0 - (edit_distance / max_length) if max_length > 0 else 1.0
    return max(0.0, accuracy)
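
# Illustrative check (values chosen here, not from the original script):
# calculate_character_accuracy("abcd", "abed") -> edit distance 1 over max length 4,
# i.e. 1 - 1/4 = 0.75.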


def calculate_token_jaccard(predicted: str, ground_truth: str) -> float:
    """Calculate token-level Jaccard similarity"""
    predicted = preprocess_sanskrit_text(predicted)
    ground_truth = preprocess_sanskrit_text(ground_truth)

    pred_tokens = set(predicted.split())
    gt_tokens = set(ground_truth.split())

    if not pred_tokens and not gt_tokens:
        return 1.0

    intersection = len(pred_tokens & gt_tokens)
    union = len(pred_tokens | gt_tokens)

    return intersection / union if union > 0 else 0.0
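
# Illustrative check (values chosen here, not from the original script):
# calculate_token_jaccard("a b c", "a b d") -> |{a, b}| / |{a, b, c, d}| = 2 / 4 = 0.5.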


def generate_response(model, processor, image: Image.Image, prompt: str) -> str:
    """Generate a transcription from the model for a single image and prompt."""
    try:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move all input tensors to the same device as the model
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
                pad_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                repetition_penalty=1.1
            )

        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs['input_ids'], generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return output_text[0] if output_text else ""

    except Exception as e:
        print(f"Error generating response: {e}")
        import traceback
        traceback.print_exc()
        return ""


def main():
    model_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
    adapter_path = './outputs/out-qwen2-5-vl'
    test_data_path = 'sanskrit_multimodal_test.json'
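
    # Expected layout of the test file (inferred from the parsing below): a JSON list
    # of samples, each with a "messages" list; the user message carries the image as
    # {"type": "image", "base64": ...} and the assistant message carries the
    # ground-truth transcription as {"type": "text", "text": ...}.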

    # Prompts to compare; only one transcription prompt is tested here
    prompts = [
        "Please transcribe the Sanskrit text shown in this image:"
    ]

    max_samples = 3

    print("Loading test data...")
    with open(test_data_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    test_data = test_data[:max_samples]
    print(f"Testing on {len(test_data)} samples")

    model, processor = load_model_and_processor(model_path, adapter_path)

    for prompt_idx, prompt in enumerate(prompts):
        print(f"\n{'='*60}")
        print(f"TESTING PROMPT {prompt_idx + 1}: {prompt}")
        print(f"{'='*60}")

        exact_matches = 0
        character_accuracies = []
        token_jaccards = []
        failed_predictions = 0

        for i, sample in enumerate(test_data):
            print(f"\n--- Sample {i+1} ---")

            try:
                # Extract the ground-truth transcription from the assistant message
                ground_truth = ""
                for message in sample['messages']:
                    if message['role'] == 'assistant':
                        for content in message['content']:
                            if content['type'] == 'text':
                                ground_truth = content['text']
                                break
                        break

                # Extract the image from the user message
                image = None
                for message in sample['messages']:
                    if message['role'] == 'user':
                        for content in message['content']:
                            if content['type'] == 'image':
                                image = decode_base64_image(content['base64'])
                                break
                        break

                if image is None:
                    print("No image found")
                    failed_predictions += 1
                    continue

                print(f"Ground Truth: {ground_truth}")

                predicted = generate_response(model, processor, image, prompt)
                print(f"Predicted: {predicted}")

                if not predicted:
                    print("Empty prediction")
                    failed_predictions += 1
                    continue

                exact_match = calculate_exact_match(predicted, ground_truth)
                char_accuracy = calculate_character_accuracy(predicted, ground_truth)
                token_jaccard = calculate_token_jaccard(predicted, ground_truth)

                print(f"Exact Match: {exact_match}")
                print(f"Character Accuracy: {char_accuracy:.4f}")
                print(f"Token Jaccard: {token_jaccard:.4f}")

                if exact_match:
                    exact_matches += 1

                character_accuracies.append(char_accuracy)
                token_jaccards.append(token_jaccard)

            except Exception as e:
                print(f"Error processing sample: {e}")
                failed_predictions += 1
                continue

        # Aggregate metrics over the samples that produced a prediction
        successful_samples = len(test_data) - failed_predictions

        if successful_samples > 0:
            exact_match_accuracy = exact_matches / successful_samples
            avg_char_accuracy = np.mean(character_accuracies)
            avg_token_jaccard = np.mean(token_jaccards)
        else:
            exact_match_accuracy = 0.0
            avg_char_accuracy = 0.0
            avg_token_jaccard = 0.0

        print(f"\n--- RESULTS FOR PROMPT {prompt_idx + 1} ---")
        print(f"Exact Match Accuracy: {exact_match_accuracy:.4f}")
        print(f"Average Character Accuracy: {avg_char_accuracy:.4f}")
        print(f"Average Token Jaccard: {avg_token_jaccard:.4f}")

    print(f"\n{'='*60}")
    print("ALL PROMPTS TESTED")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()