#!/usr/bin/env python3
"""
Convert the Sanskrit OCR dataset to Axolotl multi-modal chat_template format
"""
import json
import pandas as pd
from datasets import load_dataset
import random
from pathlib import Path
import base64
import io
from PIL import Image


def generate_questions_for_sanskrit_text():
    """Return the pool of user questions that can be asked about Sanskrit text in an image."""
    questions = [
        "What Sanskrit text is written in this image?",
        "Can you read the Sanskrit text shown in this image?",
        "What does this Sanskrit text say?",
        "Please transcribe the Sanskrit text from this image.",
        "What Sanskrit verses are visible in this image?",
        "Can you identify the Sanskrit text in this image?",
        "What is the Sanskrit content of this image?",
        "Please read the Sanskrit text displayed here.",
        "What Sanskrit shlok is written in this image?",
        "Can you translate the Sanskrit text in this image?",
        "What Sanskrit words are shown in this image?",
        "Please provide the Sanskrit text from this image.",
        "What Sanskrit script is visible in this image?",
        "Can you extract the Sanskrit text from this image?",
        "What Sanskrit verses can you see in this image?"
    ]
    return questions


def image_to_base64(image):
    """Convert a PIL image to a base64-encoded JPEG string.

    JPEG cannot represent alpha channels or palette modes, so images in
    RGBA/LA/P/PA modes are converted to RGB first (the original code raised
    ``OSError: cannot write mode RGBA as JPEG`` on such inputs).
    """
    if image.mode in ("RGBA", "LA", "P", "PA"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    img_str = base64.b64encode(buffer.getvalue()).decode()
    return img_str


def create_multimodal_entry(image, sanskrit_text, question=None):
    """Create a single entry in Axolotl multi-modal chat_template format.

    Args:
        image: PIL image containing the Sanskrit text.
        sanskrit_text: Ground-truth transcription used as the assistant reply.
        question: Optional user question; a random one is drawn from
            ``generate_questions_for_sanskrit_text()`` when None.

    Returns:
        dict with a ``messages`` list: one user turn (text + base64 image)
        and one assistant turn (the transcription).
    """
    if question is None:
        questions = generate_questions_for_sanskrit_text()
        question = random.choice(questions)

    # Convert image to base64
    image_base64 = image_to_base64(image)

    # Axolotl multi-modal chat_template format
    entry = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image", "base64": image_base64}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": sanskrit_text}
                ]
            }
        ]
    }
    return entry


def _build_entries(ds, labels_df, indices, label):
    """Build multi-modal entries for the given dataset indices.

    Indices beyond the labels table are skipped (the image dataset may be
    longer than labels.csv). Prints progress every 100 processed indices.
    """
    entries = []
    for i, idx in enumerate(indices):
        if idx < len(labels_df):
            sanskrit_text = labels_df.iloc[idx]['shlok']
            entries.append(create_multimodal_entry(ds[idx]['image'], sanskrit_text))
        if (i + 1) % 100 == 0:
            print(f"Processed {i + 1} {label} entries...")
    return entries


def _save_json(path, data):
    """Write *data* to *path* as pretty-printed UTF-8 JSON, preserving Devanagari."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def convert_to_multimodal_format():
    """Convert the Sanskrit OCR dataset to Axolotl multi-modal format.

    Loads images from the Hugging Face hub and transcriptions from a local
    labels.csv, makes a reproducible 95/5 train/test split, and writes the
    full datasets plus 3-entry sample files to JSON.

    Returns:
        (train_entries, test_entries) lists of chat_template dicts.
    """
    print("Loading dataset...")
    ds = load_dataset('snskrt/Sanskrit_OCR_Parallel_Corpus', split='train')
    labels_df = pd.read_csv('Sanskrit_OCR_Parallel_Corpus_train/labels.csv')
    print(f"Dataset has {len(ds)} images and {len(labels_df)} labels")

    # Set random seed for reproducible splits
    random.seed(42)
    indices = list(range(len(ds)))
    random.shuffle(indices)

    # Split: 95% train, 5% test
    train_size = int(len(ds) * 0.95)
    train_indices = indices[:train_size]
    test_indices = indices[train_size:]

    print(f"Creating training dataset ({len(train_indices)} entries)...")
    train_entries = _build_entries(ds, labels_df, train_indices, "training")

    print(f"Creating test dataset ({len(test_indices)} entries)...")
    test_entries = _build_entries(ds, labels_df, test_indices, "test")

    # Save full datasets
    train_file = "sanskrit_multimodal_train.json"
    print(f"Saving training dataset to {train_file}...")
    _save_json(train_file, train_entries)

    test_file = "sanskrit_multimodal_test.json"
    print(f"Saving test dataset to {test_file}...")
    _save_json(test_file, test_entries)

    # Create small sample files for manual inspection
    train_sample_file = "sanskrit_multimodal_train_sample.json"
    test_sample_file = "sanskrit_multimodal_test_sample.json"
    print(f"Creating sample files...")
    _save_json(train_sample_file, train_entries[:3])
    _save_json(test_sample_file, test_entries[:3])

    print(f"\nMulti-modal format conversion complete!")
    print(f"Training dataset: {train_file} ({len(train_entries)} entries)")
    print(f"Test dataset: {test_file} ({len(test_entries)} entries)")
    print(f"Training sample: {train_sample_file}")
    print(f"Test sample: {test_sample_file}")

    return train_entries, test_entries


if __name__ == "__main__":
    train_entries, test_entries = convert_to_multimodal_format()
    print(f"Conversion complete! Created {len(train_entries)} training and {len(test_entries)} test entries.")