# DeepSlide / src/model_downloader.py
# Uploaded by harshinde with huggingface_hub (commit 11f9299, verified; 5.67 kB).
import os
import json
import requests
import streamlit as st
from pathlib import Path
from tqdm.auto import tqdm
class ModelDownloader:
    """Download segmentation model weights from Hugging Face and cache them locally.

    Each supported model name maps to a weight file inside the
    ``harshinde/Sims`` repository's ``models/`` folder. Files are cached in
    ``models_dir``; once a file exists there its path is returned without any
    network access.
    """

    def __init__(self, models_dir="/app/models", timeout=60):
        """Create the cache directory and build the model registry.

        Args:
            models_dir (str | os.PathLike): Directory where downloaded weights
                are cached. Defaults to ``/app/models``, the previously
                hard-coded location.
            timeout (float): Value passed as ``timeout`` to ``requests.get``
                for each download (bounds connecting and each wait for data,
                not the total transfer time).
        """
        # Create models directory for caching. parents=True so construction
        # does not fail when the parent directory (e.g. /app) is missing.
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)
        self.timeout = timeout
        # HuggingFace model repository details
        self.hf_model_url = "https://huggingface.co/harshinde/Sims/resolve/main/models/"
        # Model mapping: name -> weight file name, download URL, and (for some
        # entries) a legacy "id" kept for backward compatibility.
        self.model_files = {
            "deeplabv3plus": {
                "file": "deeplabv3.pth",
                "url": f"{self.hf_model_url}deeplabv3.pth"
            },
            "densenet121": {
                "file": "densenet121.pth",
                "url": f"{self.hf_model_url}densenet121.pth"
            },
            "efficientnetb0": {
                "file": "efficientnetb0.pth",
                "url": f"{self.hf_model_url}efficientnetb0.pth"
            },
            "inceptionresnetv2": {
                "file": "inceptionresnetv2.pth",
                "id": "inceptionresnetv2"
            },
            "inceptionv4": {
                "file": "inceptionv4.pth",
                "id": "inceptionv4"
            },
            "mitb1": {
                "file": "mitb1.pth",
                "id": "mitb1"
            },
            "mobilenetv2": {
                "file": "mobilenetv2.pth",
                "id": "mobilenetv2"
            },
            "resnet34": {
                "file": "resnet34.pth",
                "id": "resnet34"
            },
            "resnext50_32x4d": {
                "file": "resnext50-32x4d.pth",
                "id": "resnext50-32x4d"
            },
            "se_resnet50": {
                "file": "se_resnet50.pth",
                "id": "se_resnet50"
            },
            "se_resnext50_32x4d": {
                "file": "se_resnext50_32x4d.pth",
                "id": "se_resnext50_32x4d"
            },
            "segformer": {
                "file": "segformer.pth",
                "id": "segformer"
            },
            "vgg16": {
                "file": "vgg16.pth",
                "id": "vgg16"
            }
        }
        # Only the first entries carry an explicit "url"; without one,
        # download_model() would fail with KeyError('url'). Derive it from the
        # file name. NOTE(review): assumes every weight file lives in the same
        # repository folder as the explicit ones - confirm on the Hub.
        for info in self.model_files.values():
            info.setdefault("url", f"{self.hf_model_url}{info['file']}")

    def _lookup(self, model_name):
        """Return ``(model_info, cache_path)`` for *model_name*.

        Raises:
            ValueError: If *model_name* is not a supported model.
        """
        if model_name not in self.model_files:
            raise ValueError(f"Model {model_name} not found. Available models: {list(self.model_files.keys())}")
        model_info = self.model_files[model_name]
        return model_info, self.models_dir / model_info["file"]

    def download_model(self, model_name):
        """Download a model's weights from Hugging Face unless already cached.

        Args:
            model_name (str): Name of the model to download (see
                ``list_available_models``).

        Returns:
            str: Path to the local model file.

        Raises:
            ValueError: If *model_name* is not a supported model.
            requests.exceptions.RequestException: If the download fails
                (e.g. ``HTTPError`` for a non-2xx response, or a timeout).
        """
        model_info, model_path = self._lookup(model_name)
        if model_path.exists():
            return str(model_path)

        print(f"Downloading {model_name} model...")
        # Stream into a sibling ".part" file and move it into place only after
        # the whole body has been written; writing straight to model_path would
        # let an interrupted download leave a truncated file that later calls
        # treat as a valid cached model.
        tmp_path = model_path.with_name(model_path.name + ".part")
        try:
            with requests.get(model_info["url"], stream=True, timeout=self.timeout) as response:
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                with open(tmp_path, "wb") as f, tqdm(
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar:
                    # Larger chunks than 1 KiB cut per-chunk overhead on big
                    # weight files.
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        pbar.update(f.write(chunk))
            tmp_path.replace(model_path)
        finally:
            # After a successful replace() the temp file no longer exists; if
            # anything above raised, discard the partial download.
            if tmp_path.exists():
                tmp_path.unlink()
        print(f"Model downloaded successfully to {model_path}")
        return str(model_path)

    def get_model_path(self, model_name):
        """Return the local path for a model file, downloading it if necessary.

        Args:
            model_name (str): Name of the model (e.g. 'deeplabv3plus',
                'densenet121').

        Returns:
            str: Path to the model file.

        Raises:
            ValueError: If *model_name* is not a supported model.
        """
        _, model_path = self._lookup(model_name)
        if model_path.exists():
            return str(model_path)
        # Not cached yet: fetch it from the Hugging Face repository.
        return self.download_model(model_name)

    def list_available_models(self):
        """Return the names of all supported models, in registry order.

        Returns:
            list: List of available model names.
        """
        return list(self.model_files.keys())