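"""Gradio app for AG News topic classification.

Loads the distilbert-base-uncased backbone, attaches the LoRA adapters published
in ADAPTER_REPO, and serves a simple text-classification UI on CPU (e.g. a
free-tier Hugging Face Space).
"""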
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel

# --- CONFIGURATION ---
# Replace with your specific repo name
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
BASE_MODEL_ID = "distilbert-base-uncased"
CLASS_NAMES = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
def load_model():
    print("Loading Base Model...")

    # 1. Load the Base Model (Generic DistilBERT)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        BASE_MODEL_ID,
        num_labels=len(CLASS_NAMES),
        id2label=CLASS_NAMES,
        label2id={v: k for k, v in CLASS_NAMES.items()}
    )

    # 2. Load the Tokenizer from YOUR repo (ensures consistency)
    tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

    # 3. Load and Apply your LoRA Adapters
    print(f"Loading LoRA Adapters from {ADAPTER_REPO}...")
    model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)

    # Optimize for CPU (Free Tier Spaces are CPU)
    device = torch.device("cpu")
    model.to(device)
    model.eval()

    return model, tokenizer, device


# Load model once on startup
model, tokenizer, device = load_model()
def classify_news(text):
    if not text:
        return None

    # Preprocess
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    ).to(device)

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)

    # Get Probabilities
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    # Format Output
    results = {}
    for i, prob in enumerate(probabilities):
        results[CLASS_NAMES[i]] = float(prob)
    return results
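
# Optional local sanity check (the example text is illustrative; running it
# requires the base model and adapter repo to be downloadable in this environment):
# print(classify_news("The central bank raised interest rates to cool inflation."))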

# Create Interface
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Paste a news article here...",
        label="News Text"
    ),
    outputs=gr.Label(num_top_classes=4, label="Prediction"),
    title="AI News Classifier (DistilBERT + LoRA)",
    description=(
        "This model classifies news into World, Sports, Business, or Sci/Tech "
        "categories. Trained on AG News using Parameter-Efficient Fine-Tuning."
    ),
    examples=[
        ["The stock market rallied today as tech companies reported record profits."],
        ["The team scored a goal in the final minute to win the championship."],
        ["New research shows that drinking coffee may increase life expectancy."],
        ["Diplomats gathered in Geneva to discuss the peace treaty."]
    ]
)
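
# On a Hugging Face Space the default launch() is sufficient; share=True is only
# needed to get a temporary public link when running the script locally.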
iface.launch()