# Joaquin Villar
# Update app.py
# 0f53989 verified
# raw
# history blame
# 2.68 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel
# --- CONFIGURATION ---
# Hugging Face Hub repo holding the trained LoRA adapter (the tokenizer is
# loaded from here as well).
ADAPTER_REPO = "jvillar-sheff/ag-news-distilbert-lora"
# Pretrained checkpoint the LoRA adapter is applied on top of.
BASE_MODEL_ID = "distilbert-base-uncased"
# Class index -> human-readable label; index i is the model's i-th logit.
CLASS_NAMES = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))
def load_model(adapter_repo=None, base_model_id=None, class_names=None, device="cpu"):
    """Load the base classifier, apply the LoRA adapter, and return it ready for inference.

    Every argument is optional. Calling ``load_model()`` with no arguments
    uses the module-level configuration constants.

    Args:
        adapter_repo: Hub repo with the LoRA adapter weights and the tokenizer.
            Defaults to ``ADAPTER_REPO``.
        base_model_id: Hub ID of the base model the adapter is applied to.
            Defaults to ``BASE_MODEL_ID``.
        class_names: Mapping of class index -> label name, used to size the
            classification head and build the label maps. Defaults to
            ``CLASS_NAMES``.
        device: Anything ``torch.device`` accepts. Defaults to CPU because
            free-tier Spaces are CPU-only.

    Returns:
        ``(model, tokenizer, device)``, where ``model`` is the adapter-wrapped
        model in eval mode on ``device`` and ``device`` is a ``torch.device``.
    """
    # Resolve defaults at call time so the module constants are read when needed.
    if adapter_repo is None:
        adapter_repo = ADAPTER_REPO
    if base_model_id is None:
        base_model_id = BASE_MODEL_ID
    if class_names is None:
        class_names = CLASS_NAMES

    print("Loading Base Model...")
    # 1. Base model with a classification head sized to the label set; the
    #    label maps let the model config carry the class names.
    base_model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        num_labels=len(class_names),
        id2label=dict(class_names),
        label2id={name: idx for idx, name in class_names.items()},
    )
    # 2. Tokenizer from the adapter repo (ensures consistency with training).
    tokenizer = AutoTokenizer.from_pretrained(adapter_repo)
    # 3. Wrap the base model with the LoRA adapter weights.
    print(f"Loading LoRA Adapters from {adapter_repo}...")
    model = PeftModel.from_pretrained(base_model, adapter_repo)

    device = torch.device(device)
    model.to(device)
    model.eval()
    return model, tokenizer, device


# Load the model once at startup so every request reuses the same weights.
model, tokenizer, device = load_model()
def classify_news(text):
    """Score a news snippet against every class in CLASS_NAMES.

    Uses the module-level ``model``, ``tokenizer`` and ``device`` loaded at
    startup.

    Args:
        text: Article text from the Gradio textbox.

    Returns:
        A dict mapping each label name to its softmax probability (the
        label -> confidence format ``gr.Label`` displays), or ``None`` when
        ``text`` is empty, ``None``, or whitespace-only.
    """
    # Whitespace-only input would otherwise still be run through the model.
    if not text or not text.strip():
        return None

    # Truncate to 128 tokens. A single input needs no padding, so the
    # original padding="max_length" only made the model process pad tokens
    # that the attention mask ignores anyway.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Take batch row 0 explicitly; squeeze() would also drop the class axis
    # if the label set ever had a single class.
    probabilities = torch.softmax(logits, dim=-1)[0].cpu().tolist()
    return {CLASS_NAMES[i]: float(prob) for i, prob in enumerate(probabilities)}
# Create Interface: a free-text input scored by classify_news, shown as a
# ranked label list covering every class.
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Paste a news article here...",
        label="News Text",
    ),
    # Derive the count from CLASS_NAMES instead of hard-coding 4 so the
    # output always lists every class the model predicts.
    outputs=gr.Label(num_top_classes=len(CLASS_NAMES), label="Prediction"),
    title="AI News Classifier (DistilBERT + LoRA)",
    description="This model classifies news into World, Sports, Business, or Sci/Tech categories. Trained on AG News using Parameter-Efficient Fine-Tuning.",
    examples=[
        ["The stock market rallied today as tech companies reported record profits."],
        ["The team scored a goal in the final minute to win the championship."],
        ["New research shows that drinking coffee may increase life expectancy."],
        ["Diplomats gathered in Geneva to discuss the peace treaty."],
    ],
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests) does not start a blocking server.
# NOTE(review): assumes the Space runs `python app.py` — confirm for the deployment.
if __name__ == "__main__":
    iface.launch()