"""Gradio app that transliterates Moroccan Darija from Latin to Arabic script.

The module wires together three pieces:

* ``CTCTransliterator`` - an embedding + BiLSTM network whose upsampled
  output is decoded greedily with CTC collapsing.
* ``FirebaseCache`` - a best-effort Firestore cache of previous results and
  user corrections.  Every method degrades to a no-op when Firestore is not
  available.
* ``create_interface`` - the Gradio Blocks UI, including an on-screen Arabic
  keyboard for editing the output before saving a correction.
"""

import base64
import hashlib
import json
import os
from datetime import datetime, timezone

import firebase_admin
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from firebase_admin import credentials, firestore
from torch.nn.utils.rnn import pad_sequence  # noqa: F401  (kept from the original import list)

# Direction labels are shown in the UI *and* folded into the cache key, so
# they must stay byte-identical everywhere they are used.
LATIN_TO_ARABIC = "Latin → Arabic"
ARABIC_TO_LATIN = "Arabic → Latin"

# Name of the Firestore collection holding one document per cached input.
TRANSLATIONS_COLLECTION = 'translations'


def _utc_now():
    """Return the current time as a timezone-aware UTC datetime.

    Aware timestamps are unambiguous once stored, regardless of the time zone
    the hosting machine happens to run in.
    """
    return datetime.now(timezone.utc)


# Define the model architecture
class CTCTransliterator(nn.Module):
    """Embedding + bidirectional LSTM encoder with temporal upsampling for CTC.

    Args:
        input_dim: Size of the source vocabulary (number of embedding rows).
        hidden_dim: Embedding width and per-direction LSTM hidden size.
        output_dim: Size of the target vocabulary (number of output classes).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used between LSTM layers and before
            the upsampling step.
        upsample_factor: Factor by which the time axis is stretched so the
            output sequence can be longer than the input sequence.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3,
                 dropout=0.3, upsample_factor=3):
        super().__init__()
        # NOTE: attribute names are part of the checkpoint format
        # (``load_state_dict`` matches on them) - do not rename.
        self.embed = nn.Embedding(input_dim, hidden_dim, padding_idx=0)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers,
                            bidirectional=True, dropout=dropout)
        # Bidirectional LSTM concatenates both directions -> 2 * hidden_dim.
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.dropout = nn.Dropout(dropout)
        self.upsample_factor = upsample_factor
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

    def forward(self, x):
        """Compute per-timestep log-probabilities.

        Args:
            x: Token-id tensor laid out time-major, ``(seq_len, batch)``
                (the LSTM is not ``batch_first``).

        Returns:
            Log-softmax scores of shape
            ``(seq_len * upsample_factor, batch, output_dim)``.
        """
        x = self.embed(x)
        x, _ = self.lstm(x)
        x = self.layer_norm(x)
        x = self.dropout(x)
        # (seq_len, batch, 2*hidden) -> (batch, 2*hidden, seq_len):
        # F.interpolate stretches the *last* axis, so time must go last.
        x = x.permute(1, 2, 0)
        x = F.interpolate(x, scale_factor=self.upsample_factor,
                          mode='linear', align_corners=False)
        # Back to time-major: (seq_len * upsample_factor, batch, 2*hidden).
        x = x.permute(2, 0, 1)
        x = self.fc(x)
        return x.log_softmax(dim=2)


# Firebase Cache System
class FirebaseCache:
    """Best-effort Firestore cache for transliteration results.

    ``self.db`` is a Firestore client, or ``None`` when initialisation failed
    or no credentials were found.  With ``self.db is None`` every public
    method returns its "nothing happened" value instead of raising.
    """

    def __init__(self):
        self.db = None
        self.init_firebase()

    @staticmethod
    def _load_credentials():
        """Return a ``credentials.Certificate`` or ``None`` if none is configured.

        Looks first at the ``FIREBASE_CREDENTIALS`` environment variable
        (base64-encoded service-account JSON), then at a local
        ``firebase-credentials.json`` file.
        """
        encoded = os.getenv('FIREBASE_CREDENTIALS')
        if encoded:
            # For hosted deployments: credentials arrive via the environment.
            cred_data = json.loads(base64.b64decode(encoded).decode())
            return credentials.Certificate(cred_data)
        if os.path.exists('firebase-credentials.json'):
            # For local development.
            return credentials.Certificate('firebase-credentials.json')
        return None

    def init_firebase(self):
        """Initialize the Firebase connection, leaving ``self.db`` as None on failure."""
        try:
            if firebase_admin._apps:
                # An app is already registered in this process; reuse it.
                self.db = firestore.client()
                return

            cred = self._load_credentials()
            if cred is None:
                print("No Firebase credentials found. Using local cache fallback.")
                return

            firebase_admin.initialize_app(cred)
            self.db = firestore.client()
            print("Firebase initialized successfully!")
        except Exception as e:
            # Deliberately broad: the cache is optional and must never stop
            # the app from starting.
            print(f"Firebase initialization failed: {e}")
            print("Falling back to local cache mode")
            self.db = None

    def _create_cache_key(self, input_text, direction):
        """Create a safe document key for Firestore.

        The text is hashed so that special characters and long inputs cannot
        produce an invalid document id.  MD5 is used purely as a key
        function here, not for security.
        """
        key = f"{input_text}_{direction}"
        return hashlib.md5(key.encode()).hexdigest()

    def _document(self, input_text, direction):
        """Return the Firestore document reference for this input/direction."""
        doc_key = self._create_cache_key(input_text, direction)
        return self.db.collection(TRANSLATIONS_COLLECTION).document(doc_key)

    def get(self, input_text, direction):
        """Get a cached translation from Firebase.

        Returns:
            The stored ``output`` string on a hit (and bumps the document's
            usage counter), or ``None`` on a miss, on error, or when
            Firestore is unavailable.
        """
        if self.db is None:
            return None
        try:
            doc_ref = self._document(input_text, direction)
            doc = doc_ref.get()
            if not doc.exists:
                return None

            data = doc.to_dict()
            # Update usage bookkeeping on every hit.
            doc_ref.update({
                'usage_count': data.get('usage_count', 0) + 1,
                'last_used': _utc_now(),
            })
            print(f"Cache hit: {input_text}")
            return data.get('output', '')
        except Exception as e:
            print(f"Cache read error: {e}")
            return None

    def set(self, input_text, direction, output):
        """Store a translation in Firebase.

        Returns:
            ``True`` if the document was written, ``False`` otherwise.
        """
        if self.db is None:
            return False
        try:
            now = _utc_now()
            doc_data = {
                'input': input_text,
                'direction': direction,
                'output': output,
                'corrected_output': '',
                'timestamp': now,
                'last_used': now,
                'usage_count': 1,
            }
            self._document(input_text, direction).set(doc_data)
            print(f"Cached: {input_text} → {output}")
            return True
        except Exception as e:
            print(f"Cache write error: {e}")
            return False

    def update_correction(self, input_text, direction, corrected_output):
        """Attach a user correction to an existing cached translation.

        ``input_text`` must be the same (normalised) text that was used when
        the translation was cached, otherwise the key will not match.

        Returns:
            ``True`` if the correction was saved, ``False`` otherwise
            (including when no cached document exists for this input).
        """
        if self.db is None:
            return False
        try:
            self._document(input_text, direction).update({
                'corrected_output': corrected_output,
                'correction_timestamp': _utc_now(),
            })
            print(f"Correction saved: {input_text} → {corrected_output}")
            return True
        except Exception as e:
            print(f"Correction save error: {e}")
            return False

    def get_stats(self):
        """Return a human-readable summary of the cache contents."""
        if self.db is None:
            return "Firebase not connected"
        try:
            docs = self.db.collection(TRANSLATIONS_COLLECTION).get()
            total = len(docs)
            corrected = 0
            total_usage = 0
            for doc in docs:
                data = doc.to_dict()
                if data.get('corrected_output'):
                    corrected += 1
                total_usage += data.get('usage_count', 0)

            average = total_usage / total if total > 0 else 0
            return "\n".join([
                "Cache Statistics:",
                f"• Total translations: {total}",
                f"• With corrections: {corrected}",
                f"• Total usage count: {total_usage}",
                f"• Average usage: {average:.1f} per translation",
            ])
        except Exception as e:
            return f"Error getting stats: {e}"


def _load_json(path):
    """Read and parse a UTF-8 encoded JSON file."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Load vocabularies and model
def load_model_and_vocabs():
    """Load the four vocabulary maps and the trained model from the working directory.

    Returns:
        ``(model, latin_stoi, latin_itos, arabic_stoi, arabic_itos,
        blank_id, device)`` with the model already moved to ``device`` and
        switched to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load vocabularies
    latin_stoi = _load_json('latin_stoi.json')
    latin_itos = _load_json('latin_itos.json')
    arabic_stoi = _load_json('arabic_stoi.json')
    arabic_itos = _load_json('arabic_itos.json')

    # Initialize model - these hyper-parameters must match the checkpoint.
    model = CTCTransliterator(
        len(latin_stoi),
        256,
        len(arabic_stoi),
        num_layers=3,
        dropout=0.3,
        upsample_factor=2,
    ).to(device)

    # Load trained weights
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()

    # NOTE(review): the lookup key is the empty string; it looks like a
    # special-token name (e.g. a blank marker) may have been lost from the
    # source. Confirm against the vocabulary files. Falls back to the last
    # index when the key is absent.
    blank_id = arabic_stoi.get('', len(arabic_itos) - 1)

    return model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device


# Load everything at startup
model, latin_stoi, latin_itos, arabic_stoi, arabic_itos, blank_id, device = load_model_and_vocabs()
firebase_cache = FirebaseCache()


def encode_text(text, vocab):
    """Encode text as a 1-D long tensor of vocabulary ids.

    Leading/trailing whitespace is stripped first; characters missing from
    ``vocab`` are mapped to id 0.
    """
    return torch.tensor([vocab.get(ch, 0) for ch in text.strip()], dtype=torch.long)


def greedy_decode(log_probs, blank_id, itos, stoi):
    """Decode CTC outputs using greedy (best-path) decoding.

    Args:
        log_probs: Scores laid out ``(time, batch, classes)``.
        blank_id: Class id of the CTC blank symbol.
        itos: Mapping from *stringified* class id to output character
            (JSON object keys are strings, hence ``str(p)``).
        stoi: Mapping from output character to class id; used only to find
            the end-of-sequence id.

    Returns:
        One decoded string per batch element.
    """
    # NOTE(review): as with the blank id, the lookup key is the empty string
    # and may originally have been a special-token name - confirm against
    # the vocabulary. Falls back to the second-to-last index.
    eos_id = stoi.get('', len(stoi) - 2)

    # argmax over classes -> (time, batch); transpose -> (batch, time).
    # ``tolist`` yields plain Python ints, so ``str(p)`` matches the JSON keys.
    batch_predictions = log_probs.argmax(2).T.tolist()

    results = []
    for prediction in batch_predictions:
        prev = None
        decoded = []
        for p in prediction:
            if p == eos_id:
                # Stop at EOS.
                break
            # CTC collapse: skip blanks and immediate repeats.
            if p != blank_id and p != prev:
                decoded.append(itos[str(p)])
            # Track the previous symbol on *every* step (blanks included) so
            # that "a <blank> a" decodes to "aa" rather than "a".
            prev = p
        results.append("".join(decoded))
    return results


def _normalize_input(text, direction):
    """Return ``text`` in the canonical form used for inference and cache keys.

    Latin input is lower-cased before it is transliterated and cached, so
    every other code path that addresses the same cache entry must apply the
    same normalisation.
    """
    if direction == LATIN_TO_ARABIC:
        return text.lower()
    return text


def transliterate_latin_to_arabic(text):
    """Transliterate Latin script to Arabic script with Firebase caching.

    Returns the empty string for empty/whitespace-only input, the cached
    result when available, otherwise the model output.  Inference errors are
    reported as an ``"Error: ..."`` string rather than raised, so the UI
    always has something to display.
    """
    if not text or not text.strip():
        return ""

    # Check Firebase cache first
    cached_result = firebase_cache.get(text, LATIN_TO_ARABIC)
    if cached_result:
        return cached_result

    try:
        # Encode input text as (seq_len, 1): a batch of one, time-major.
        src = encode_text(text, latin_stoi).unsqueeze(1).to(device)

        # Generate prediction
        with torch.no_grad():
            out = model(src)

        # Decode output
        decoded = greedy_decode(out, blank_id, arabic_itos, arabic_stoi)
        result = decoded[0] if decoded else ""

        # Cache the result in Firebase
        firebase_cache.set(text, LATIN_TO_ARABIC, result)
        return result
    except Exception as e:
        return f"Error: {str(e)}"


def transliterate_arabic_to_latin(text):
    """Transliterate Arabic script to Latin script (placeholder)."""
    return "Arabic to Latin transliteration not implemented yet."


def transliterate(text, direction):
    """Main transliteration entry point used by the UI."""
    if direction == LATIN_TO_ARABIC:
        return transliterate_latin_to_arabic(_normalize_input(text or "", direction))
    return transliterate_arabic_to_latin(text)


def save_correction(input_text, direction, corrected_output):
    """Save a user correction to Firebase and return a status message.

    The input is normalised exactly as ``transliterate`` does, so the
    correction is written to the same cache document that the translation
    was stored under.
    """
    cache_text = _normalize_input(input_text or "", direction)
    if firebase_cache.update_correction(cache_text, direction, corrected_output):
        return "Correction saved to the database! Thank you for improving the model."
    return "Could not save correction to database."


# Arabic keyboard layout
arabic_keys = [
    ['ض', 'ص', 'ث', 'ق', 'ف', 'غ', 'ع', 'ه', 'خ', 'ح', 'ج', 'د'],
    ['ش', 'س', 'ي', 'ب', 'ل', 'ا', 'ت', 'ن', 'م', 'ك', 'ط'],
    ['ئ', 'ء', 'ؤ', 'ر', 'لا', 'ى', 'ة', 'و', 'ز', 'ظ'],
    ['ذ', '١', '٢', '٣', '٤', '٥', '٦', '٧', '٨', '٩', '٠'],
]

# Markdown shown at the top of the page.  Kept flush-left at module level so
# indentation can never be interpreted as a Markdown code block.
_HEADER_MARKDOWN = """
# Darija Transliterator
Convert between Latin script and Arabic script for Moroccan Darija

**Firebase-Powered**: Persistent caching across sessions
**Arabic Keyboard**: Built-in Arabic keyboard for corrections
**Real-time Stats**: Live usage analytics
"""

# Markdown shown at the bottom of the page.
_ABOUT_MARKDOWN = """
### About
This model transliterates Moroccan Darija between Latin and Arabic scripts using a CTC-based neural network.

**Firebase Features:**
- **Persistent Storage**: All translations are saved permanently
- **Analytics**: Track usage patterns and popular translations
- **Fast Responses**: Cached results load instantly
- **Global Access**: Data synced across all users
- **Corrections**: Help improve the model by fixing outputs

**How to help improve the model:**
1. Use the Arabic keyboard to correct any wrong translations
2. Click "Save Correction" to store your improvement
3. Your corrections help train better models for everyone!
"""


def _bind_js_edit(button, js, output_text, *, reads_output=True):
    """Attach a client-side-only edit of ``output_text`` to ``button``.

    ``fn=None`` with a ``js`` snippet means the edit runs entirely in the
    browser: no server round trip, no queueing, no progress indicator.
    """
    kwargs = {
        'fn': None,
        'js': js,
        'outputs': [output_text],
        'show_progress': False,
        'queue': False,
    }
    if reads_output:
        kwargs['inputs'] = [output_text]
    button.click(**kwargs)


def _build_arabic_keyboard(output_text):
    """Render the on-screen Arabic keyboard and wire it to ``output_text``."""
    gr.Markdown("### Arabic Keyboard")
    gr.Markdown("*Click letters to edit the output text above*")

    with gr.Group():
        for row in arabic_keys:
            with gr.Row():
                for char in row:
                    btn = gr.Button(char, size="sm", scale=1)
                    # The f-string is evaluated now, so each button captures
                    # its own character (no late-binding problem).
                    _bind_js_edit(
                        btn,
                        f"(output_text) => output_text + '{char}'",
                        output_text,
                    )

        with gr.Row():
            space_btn = gr.Button("Space", size="sm", scale=2)
            backspace_btn = gr.Button("⌫ Backspace", size="sm", scale=2)
            clear_output_btn = gr.Button("Clear Output", size="sm", scale=2)

    # Keyboard utility buttons
    _bind_js_edit(space_btn, "(output_text) => output_text + ' '", output_text)
    _bind_js_edit(backspace_btn, "(output_text) => output_text.slice(0, -1)", output_text)
    _bind_js_edit(clear_output_btn, "() => ''", output_text, reads_output=False)


# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks application (not yet launched)."""
    with gr.Blocks(title="Darija Transliterator", theme=gr.themes.Soft()) as demo:
        gr.Markdown(_HEADER_MARKDOWN)

        # Stats section
        with gr.Row():
            stats_btn = gr.Button("Show Statistics", variant="secondary")
            stats_display = gr.Textbox(
                label="Firebase Statistics",
                interactive=False,
                visible=False,
                lines=5,
            )

        with gr.Row():
            with gr.Column(scale=1):
                direction = gr.Radio(
                    choices=[LATIN_TO_ARABIC, ARABIC_TO_LATIN],
                    value=LATIN_TO_ARABIC,
                    label="Translation Direction",
                )
                input_text = gr.Textbox(
                    placeholder="Enter text to transliterate...",
                    label="Input Text",
                    lines=4,
                    max_lines=10,
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    translate_btn = gr.Button("Transliterate", variant="primary")

            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Output",
                    lines=4,
                    max_lines=10,
                    interactive=True,  # editable so users can correct it
                )

        # Arabic Keyboard
        # NOTE(review): the original widget nesting was lost when the source
        # was collapsed; the keyboard is placed full-width below the
        # input/output row - confirm this matches the intended layout.
        _build_arabic_keyboard(output_text)

        # Correction system
        with gr.Group():
            gr.Markdown("### Correction System")
            correction_status = gr.Textbox(
                label="Status",
                interactive=False,
                visible=False,
            )
            save_correction_btn = gr.Button("Save Correction", variant="secondary")

        # Stats button: fetch the text first, then reveal the hidden textbox.
        stats_btn.click(
            fn=firebase_cache.get_stats,
            outputs=[stats_display],
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[stats_display],
        )

        # Example inputs
        gr.Markdown("### Examples")
        examples = [
            ["makay3nich bli katkhdam bzaf", LATIN_TO_ARABIC],
            ["rah bayn dkchi li katdir kolchi 3ay9 bik", LATIN_TO_ARABIC],
            ["wach na9dar nakhod caipirinha, 3afak", LATIN_TO_ARABIC],
            ["ghadi temchi f lkhedma mzyan", LATIN_TO_ARABIC],
        ]
        gr.Examples(
            examples=examples,
            inputs=[input_text, direction],
            outputs=output_text,
            fn=transliterate,
            cache_examples=False,
        )

        # Event handlers
        translate_btn.click(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text,
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status],
        )

        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[input_text, output_text],
        )

        input_text.submit(
            fn=transliterate,
            inputs=[input_text, direction],
            outputs=output_text,
        )

        save_correction_btn.click(
            fn=save_correction,
            inputs=[input_text, direction, output_text],
            outputs=[correction_status],
        ).then(
            fn=lambda: gr.update(visible=True),
            outputs=[correction_status],
        )

        # Information
        gr.Markdown(_ABOUT_MARKDOWN)

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)