harshinde commited on
Commit
22c4819
·
verified ·
1 Parent(s): 10968f3

Upload src/model_downloader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/model_downloader.py +142 -0
src/model_downloader.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ import streamlit as st
5
+ from pathlib import Path
6
+ from tqdm.auto import tqdm
7
+
8
+ class ModelDownloader:
9
+ def __init__(self):
10
+ # Create models directory for caching
11
+ self.models_dir = Path("/app/models")
12
+ self.models_dir.mkdir(exist_ok=True)
13
+
14
+ # Kaggle model repository details
15
+ self.kaggle_model_url = "https://www.kaggle.com/models/harshshinde8/sims/frameworks/PyTorch/serve"
16
+
17
+ # Model mapping with Kaggle model IDs
18
+ self.model_files = {
19
+ "deeplabv3plus": {
20
+ "file": "deeplabv3.pth",
21
+ "id": "deeplabv3"
22
+ },
23
+ "densenet121": {
24
+ "file": "densenet121.pth",
25
+ "id": "densenet121"
26
+ },
27
+ "efficientnetb0": {
28
+ "file": "effucientnetb0.pth",
29
+ "id": "effucientnetb0"
30
+ },
31
+ "inceptionresnetv2": {
32
+ "file": "inceptionresnetv2.pth",
33
+ "id": "inceptionresnetv2"
34
+ },
35
+ "inceptionv4": {
36
+ "file": "inceptionv4.pth",
37
+ "id": "inceptionv4"
38
+ },
39
+ "mitb1": {
40
+ "file": "mitb1.pth",
41
+ "id": "mitb1"
42
+ },
43
+ "mobilenetv2": {
44
+ "file": "mobilenetv2.pth",
45
+ "id": "mobilenetv2"
46
+ },
47
+ "resnet34": {
48
+ "file": "resnet34.pth",
49
+ "id": "resnet34"
50
+ },
51
+ "resnext50_32x4d": {
52
+ "file": "resnext50-32x4d.pth",
53
+ "id": "resnext50-32x4d"
54
+ },
55
+ "se_resnet50": {
56
+ "file": "se_resnet50.pth",
57
+ "id": "se_resnet50"
58
+ },
59
+ "se_resnext50_32x4d": {
60
+ "file": "se_resnext50_32x4d.pth",
61
+ "id": "se_resnext50_32x4d"
62
+ },
63
+ "segformer": {
64
+ "file": "segformer.pth",
65
+ "id": "segformer"
66
+ },
67
+ "vgg16": {
68
+ "file": "vgg16.pth",
69
+ "id": "vgg16"
70
+ }
71
+ }
72
+
73
+ def download_from_kaggle(self, model_name):
74
+ """
75
+ Download model from Kaggle Models repository
76
+ Args:
77
+ model_name (str): Name of the model to download
78
+ Returns:
79
+ str: Path to the downloaded model file
80
+ """
81
+ if model_name not in self.model_files:
82
+ raise ValueError(f"Model {model_name} not found. Available models: {list(self.model_files.keys())}")
83
+
84
+ model_info = self.model_files[model_name]
85
+ model_path = self.models_dir / model_info['file']
86
+
87
+ # If model already exists, return path
88
+ if model_path.exists():
89
+ return str(model_path)
90
+
91
+ # Construct download URL for the specific model
92
+ download_url = f"{self.kaggle_model_url}/{model_info['id']}/1"
93
+
94
+ try:
95
+ st.info(f"Downloading model {model_name} from Kaggle Models...")
96
+ progress_bar = st.progress(0)
97
+
98
+ # Download with progress
99
+ response = requests.get(download_url, stream=True)
100
+ response.raise_for_status()
101
+
102
+ total_size = int(response.headers.get('content-length', 0))
103
+ block_size = 1024 # 1 Kibibyte
104
+
105
+ with open(model_path, 'wb') as f:
106
+ for i, data in enumerate(response.iter_content(block_size)):
107
+ progress_bar.progress(min(i * block_size / total_size, 1.0))
108
+ f.write(data)
109
+
110
+ st.success(f"Successfully downloaded {model_name}")
111
+ return str(model_path)
112
+
113
+ except requests.exceptions.RequestException as e:
114
+ raise Exception(f"Failed to download model from Kaggle: {str(e)}")
115
+
116
+ def get_model_path(self, model_name):
117
+ """
118
+ Get the path for a model file, downloading it from Kaggle if necessary
119
+ Args:
120
+ model_name (str): Name of the model (e.g., 'deeplabv3plus', 'densenet121', etc.)
121
+ Returns:
122
+ str: Path to the model file
123
+ """
124
+ if model_name not in self.model_files:
125
+ raise ValueError(f"Model {model_name} not found. Available models: {list(self.model_files.keys())}")
126
+
127
+ model_info = self.model_files[model_name]
128
+ model_path = self.models_dir / model_info['file']
129
+
130
+ # If model doesn't exist locally, download it
131
+ if not model_path.exists():
132
+ return self.download_from_kaggle(model_name)
133
+
134
+ return str(model_path)
135
+
136
+ def list_available_models(self):
137
+ """
138
+ List all available models
139
+ Returns:
140
+ list: List of available model names
141
+ """
142
+ return list(self.model_files.keys())