import json
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Model and tokenizer identifiers (module-level constants so both the loader
# and any future tooling reference a single source of truth).
MODEL_ID = "spawn99/modernbert-wine-classification"
TOKENIZER_ID = "answerdotai/ModernBERT-base"

# Hard limit on description length enforced by predict_wine_variety.
MAX_DESCRIPTION_CHARS = 750


@lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the tokenizer and classifier exactly once per process.

    The previous implementation re-downloaded/re-loaded both artifacts on
    every inference call, which made each prediction pay the full model
    load cost. lru_cache turns this into a one-time lazy initialization.

    Returns:
        tuple: (tokenizer, model) with the model already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    # ModernBERT encoder-only classifier fine-tuned for wine varieties.
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()
    return tokenizer, model


def run_inference(review_text: str) -> str:
    """
    Perform inference on the given wine review text and return the predicted
    wine variety using ModernBERT, an encoder-only classifier from
    "spawn99/modernbert-wine-classification".

    Args:
        review_text (str): Wine review text in the format
            "country [SEP] description".

    Returns:
        str: The predicted wine variety using the model's id2label mapping
            if available, otherwise the raw class index as a string.
    """
    tokenizer, model = _load_model_and_tokenizer()

    # Tokenize the input text; truncate long reviews to the model's window.
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256,
    )

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Map the argmax class index to a human-readable label when the model
    # config provides one; fall back to the numeric index as a string.
    pred = torch.argmax(logits, dim=-1).item()
    id2label = getattr(model.config, "id2label", None)
    if id2label:
        return id2label.get(pred, str(pred))
    return str(pred)


def predict_wine_variety(country: str, description: str, output_format: str) -> str:
    """
    Combine the provided country and description, perform inference, and
    format the output based on the selected output format.

    Enforces a maximum character limit of 750 on the description.

    Args:
        country (str): The country of wine origin.
        description (str): The wine review description.
        output_format (str): Either "JSON" to return output as a
            JSON-formatted string, or "Text" for plain text output.

    Returns:
        str: The predicted wine variety formatted as JSON (if selected) or
            as plain text.
    """
    as_json = output_format.lower() == "json"

    if len(description) > MAX_DESCRIPTION_CHARS:
        error_msg = "Description exceeds 750 character limit. Please shorten your input."
        return json.dumps({"error": error_msg}, indent=2) if as_json else error_msg

    # NOTE(review): str.capitalize() lowercases everything after the first
    # character, which mangles proper nouns in the description — presumably
    # this matches how the training data was normalized; confirm before
    # changing.
    review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
    predicted_variety = run_inference(review_text)

    if as_json:
        return json.dumps({"Variety": predicted_variety}, indent=2)
    return predicted_variety


if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict_wine_variety,
        inputs=[
            gr.Textbox(label="Country", placeholder="Enter country of origin..."),
            gr.Textbox(label="Description", placeholder="Enter wine review description..."),
            # Radio input to choose between JSON and plain text output formats.
            gr.Radio(choices=["JSON", "Text"], value="JSON", label="Output Format"),
        ],
        # Textbox output so that plain text results display naturally.
        outputs=gr.Textbox(label="Prediction"),
        title="Wine Variety Predictor",
        description=(
            "Predict the wine variety based on the country and wine review.\n\n"
            "This tool uses ModernBERT, an encoder-only classifier, trained on the wine reviews dataset\n"
            "(model: spawn99/modernbert-wine-classification, dataset: spawn99/wine-reviews).\n\n"
            "Use the Output Format selector to toggle between a JSON-formatted result and a plain text prediction."
        ),
    )
    iface.launch()