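# Gradio demo app for the AML Group Project "PsychicFireSong": classifies plant
# species from an uploaded image using a ConvNeXt V2 model fine-tuned on
# herbarium sheets, with two placeholder model slots selectable from a dropdown.
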
import gradio as gr
import torch
import pandas as pd
import os
from torchvision import transforms
from PIL import Image
from transformers import ConvNextV2ForImageClassification

# --- Configuration ---
# Paths are relative to the app's root directory in the Hugging Face Space
DATA_DIR = '.'
LIST_DIR = os.path.join(DATA_DIR, 'list')
MODEL_PATH_HERBARIUM = os.path.join(DATA_DIR, 'herbarium_convnext_v2_base.pth')
SPECIES_LIST_TXT = os.path.join(LIST_DIR, 'species_list.txt')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Load Species Information ---
try:
    species_df = pd.read_csv(SPECIES_LIST_TXT, sep=';', header=None, names=['class_id', 'species_name'])
    class_names = list(species_df['species_name'])
    num_labels = len(class_names)
except FileNotFoundError:
    # Fallback if the species list is not found
    class_names = [f"Class {i}" for i in range(100)]  # Assuming 100 classes as a fallback
    num_labels = 100
    print(f"Warning: '{SPECIES_LIST_TXT}' not found. Using generic class names.")

# --- Image Transformations ---
data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# --- Model Loading ---
def load_herbarium_model():
    """Loads the Herbarium ConvNextV2 model."""
    model = ConvNextV2ForImageClassification.from_pretrained(
        "facebook/convnextv2-base-22k-224",
        num_labels=num_labels,
        ignore_mismatched_sizes=True
    )
    try:
        # Load the state dictionary
        model.load_state_dict(torch.load(MODEL_PATH_HERBARIUM, map_location=DEVICE))
    except FileNotFoundError:
        print(f"Warning: Model weights not found at '{MODEL_PATH_HERBARIUM}'. The model is using pre-trained weights, not fine-tuned ones.")
    except Exception as e:
        print(f"Error loading model weights: {e}. The model is using pre-trained weights.")
    model = model.to(DEVICE)
    model.eval()
    return model

# Load the primary model
herbarium_model = load_herbarium_model()

# --- Prediction Functions ---
def predict_herbarium(image):
    """Runs inference on the herbarium model."""
    if image is None:
        return "Please upload an image."
    # Preprocess the image (convert to RGB in case of RGBA or greyscale uploads)
    image = data_transforms(image.convert("RGB")).unsqueeze(0)
    image = image.to(DEVICE)
    # Get model predictions
    with torch.no_grad():
        outputs = herbarium_model(image).logits
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
    # Get top 5 predictions
    top5_prob, top5_indices = torch.topk(probabilities, 5)
    # Format results as {species_name: probability}; gr.Label expects float confidences
    results = {class_names[i.item()]: p.item() for i, p in zip(top5_indices, top5_prob)}
    return results

def predict_placeholder_1(image):
    """Placeholder function for the second model."""
    if image is None:
        return "Please upload an image."
    return "Model 2 is not available yet. Please check back later."

def predict_placeholder_2(image):
    """Placeholder function for the third model."""
    if image is None:
        return "Please upload an image."
    return "Model 3 is not available yet. Please check back later."

# --- Main Prediction Logic ---
def predict(model_choice, image):
    """Routes the prediction to the chosen model."""
    if model_choice == "Herbarium Species Classifier":
        return predict_herbarium(image)
    elif model_choice == "Future Model 1 (Placeholder)":
        return predict_placeholder_1(image)
    elif model_choice == "Future Model 2 (Placeholder)":
        return predict_placeholder_2(image)
    else:
        return "Invalid model selected."

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🌿 Plant Species Classification
        ## AML Group Project - PsychicFireSong
        Upload an image of a plant to classify it. Select a model from the dropdown below.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=[
                    "Herbarium Species Classifier",
                    "Future Model 1 (Placeholder)",
                    "Future Model 2 (Placeholder)"
                ],
                value="Herbarium Species Classifier"
            )
            image_input = gr.Image(type="pil", label="Upload Plant Image")
            submit_button = gr.Button("Classify", variant="primary")
        with gr.Column(scale=1):
            output_label = gr.Label(label="Top 5 Predictions", num_top_classes=5)
    submit_button.click(
        fn=predict,
        inputs=[model_selector, image_input],
        outputs=output_label
    )
    gr.Examples(
        examples=[
            # Add paths to example images if you have any in your project
            # e.g., os.path.join("examples", "example1.jpg")
        ],
        inputs=image_input,
        outputs=output_label,
        fn=lambda img: predict("Herbarium Species Classifier", img),
        cache_examples=False
    )

if __name__ == "__main__":
    demo.launch()