# Author: PsychicFireSong
# Last commit: 948f69d - "Fix: Expose Gradio app for Hugging Face Spaces"
import gradio as gr
import torch
import pandas as pd
import os
from torchvision import transforms
from PIL import Image
from transformers import ConvNextV2ForImageClassification
# --- Configuration ---
# Resolve paths against this file's own directory rather than the current
# working directory, so assets are found no matter where the app is launched
# from (when launched from the app's root directory the two coincide).
DATA_DIR = os.path.dirname(os.path.abspath(__file__))
LIST_DIR = os.path.join(DATA_DIR, 'list')
# Fine-tuned classifier weights, stored as a state dict.
MODEL_PATH_HERBARIUM = os.path.join(DATA_DIR, 'herbarium_convnext_v2_base.pth')
# ';'-separated species list with no header row: "<class_id>;<species_name>".
SPECIES_LIST_TXT = os.path.join(LIST_DIR, 'species_list.txt')
# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# --- Load Species Information ---
# Number of generic labels to fall back on when no usable species list exists.
# NOTE(review): presumably this matches the class count the fine-tuned weights
# were trained with - confirm, otherwise those weights will not load.
_FALLBACK_NUM_CLASSES = 100
try:
    # One row per class, ';'-separated, no header. Row order defines the
    # mapping from the model's output index to the species name.
    species_df = pd.read_csv(SPECIES_LIST_TXT, sep=';', header=None, names=['class_id', 'species_name'])
    class_names = list(species_df['species_name'])
except (FileNotFoundError, pd.errors.EmptyDataError):
    # Fallback if the species list is missing or empty
    class_names = [f"Class {i}" for i in range(_FALLBACK_NUM_CLASSES)]
    print(f"Warning: '{SPECIES_LIST_TXT}' not found or empty. Using generic class names.")
# Derived from the names in both branches so the two can never disagree.
num_labels = len(class_names)
# --- Image Transformations ---
# Standard ImageNet per-channel (R, G, B) statistics used for normalisation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for every input image: resize the shorter side to 256 px,
# take the central 224x224 crop, convert to a float tensor, then normalise
# each channel with the statistics above.
data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# --- Model Loading ---
def load_herbarium_model(weights_path=MODEL_PATH_HERBARIUM, device=DEVICE):
    """Build the Herbarium ConvNeXt V2 classifier and load its fine-tuned weights.

    The architecture and backbone weights come from the
    ``facebook/convnextv2-base-22k-224`` checkpoint via ``from_pretrained``,
    configured with ``num_labels`` output classes. The fine-tuned state dict
    at ``weights_path`` is then loaded on top.

    Args:
        weights_path: Path to the fine-tuned state dict. Defaults to
            ``MODEL_PATH_HERBARIUM``.
        device: Device used as ``map_location`` for the weights and that the
            model is moved to. Defaults to ``DEVICE``.

    Returns:
        The model on ``device``, switched to evaluation mode.

    Loading the fine-tuned weights is best-effort. On failure a message is
    printed and the model is still returned, but its predictions should not
    be trusted.
    """
    model = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        # Lets the classification head differ in size from the checkpoint's;
        # mismatched weights are freshly initialised instead of raising.
        ignore_mismatched_sizes=True,
    )
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code - only ship trusted weight files at this path.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    except FileNotFoundError:
        print(f"Warning: Model weights not found at '{weights_path}'. Only the pre-trained backbone is loaded; the classification head is not fine-tuned, so predictions are not meaningful.")
    except Exception as e:
        # load_state_dict may already have copied some tensors before raising,
        # so the model can be in a partially fine-tuned state here.
        print(f"Error loading model weights: {e}. The model may be only partially fine-tuned; predictions are not reliable.")
    model = model.to(device)
    model.eval()
    return model


# Load the primary model
herbarium_model = load_herbarium_model()
# --- Prediction Functions ---
def predict_herbarium(image, top_k=5):
    """Classify an image with the Herbarium model.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when
            nothing has been uploaded.
        top_k: Maximum number of predictions to return (default 5). Fewer
            are returned when the model has fewer classes.

    Returns:
        A ``{species_name: probability}`` dict holding the ``top_k`` most
        likely classes, with probabilities as plain floats (the numeric
        confidences ``gr.Label`` expects). Returns a prompt string instead
        when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."
    # ToTensor keeps the image's channel count, but Normalize is configured
    # for exactly three channels, so RGBA / grayscale / palette images would
    # fail. Convert them to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Preprocess and add a batch dimension of size 1.
    batch = data_transforms(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = herbarium_model(batch).logits
    # Softmax over the class dimension; [0] picks the single batch element.
    probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    return {
        class_names[index]: prob
        for index, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
def predict_placeholder_1(image):
    """Stand-in for the second model slot, which has no model yet.

    Returns an upload prompt when ``image`` is ``None``; otherwise returns a
    notice that the model is unavailable. The image content is never used.
    """
    if image is not None:
        return "Model 2 is not available yet. Please check back later."
    return "Please upload an image."
def predict_placeholder_2(image):
    """Stand-in for the third model slot, which has no model yet.

    Returns an upload prompt when ``image`` is ``None``; otherwise returns a
    notice that the model is unavailable. The image content is never used.
    """
    if image is not None:
        return "Model 3 is not available yet. Please check back later."
    return "Please upload an image."
# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Dispatch ``image`` to the predictor matching ``model_choice``.

    Args:
        model_choice: Label chosen in the "Select Model" dropdown.
        image: Passed through unchanged to the selected predictor.

    Returns:
        The selected predictor's result, or an error string when
        ``model_choice`` is not a recognised label.
    """
    if model_choice == "Herbarium Species Classifier":
        handler = predict_herbarium
    elif model_choice == "Future Model 1 (Placeholder)":
        handler = predict_placeholder_1
    elif model_choice == "Future Model 2 (Placeholder)":
        handler = predict_placeholder_2
    else:
        return "Invalid model selected."
    return handler(image)
# --- Gradio Interface ---
# Dropdown labels. These must match the strings predict() dispatches on;
# the first entry is the default selection.
MODEL_CHOICES = [
    "Herbarium Species Classifier",
    "Future Model 1 (Placeholder)",
    "Future Model 2 (Placeholder)",
]
# Paths to example images shown under the input, e.g.
# os.path.join("examples", "example1.jpg"). Leave empty to show no examples.
EXAMPLE_IMAGES = []

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌿 Plant Species Classification
        ## AML Group Project - PsychicFireSong
        Upload an image of a plant to classify it. Select a model from the dropdown below.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=MODEL_CHOICES,
                value=MODEL_CHOICES[0],
            )
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_button = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            output_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
    submit_button.click(
        fn=predict,
        inputs=[model_selector, image_input],
        outputs=output_label,
    )
    # Only build the examples panel when there is something to show; an empty
    # list gives users nothing to click.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=image_input,
            outputs=output_label,
            # Examples always run through the default (first) model.
            fn=lambda img: predict(MODEL_CHOICES[0], img),
            cache_examples=False,
        )
if __name__ == "__main__":
    # Start the server exactly once, and only when this file is run as a
    # script; importing the module just builds `demo` without launching it.
    demo.launch()