AyobamiMichael committed on
Commit ff45240 · verified · 1 Parent(s): 912dfe7

Upload 6 files

Files changed (6)
  1. .gitattributes +1 -35
  2. .gitignore +52 -0
  3. app.py +528 -0
  4. requirements.txt +8 -0
  5. test_inference.py +259 -0
  6. verify_model.py +33 -0
.gitattributes CHANGED
@@ -1,35 +1 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
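The rewrite above strips the Space's default LFS rules down to a single one: only `*.pth` weight files ride through Git LFS. One practical side effect: a clone made without `git-lfs` installed gets a small text pointer in place of the real checkpoint. A minimal Python sketch for telling the two apart — the path matches the one `app.py` uses; the header check reflects the standard LFS pointer format:

from pathlib import Path

# LFS pointer stubs are tiny text files beginning with a fixed version header,
# while the real checkpoint is a multi-megabyte binary.
weights = Path("model/ecoscan_model.pth")
if weights.exists():
    head = weights.read_bytes()[:64]
    if head.startswith(b"version https://git-lfs.github.com/spec/v1"):
        print("LFS pointer only - run `git lfs pull` to fetch the real weights.")
    else:
        print(f"Checkpoint present ({weights.stat().st_size / 1e6:.1f} MB).")
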
.gitignore ADDED
@@ -0,0 +1,52 @@
+ # Note: "#" only starts a comment at the beginning of a line in .gitignore,
+ # so patterns are listed bare (an inline "# ..." would become part of the pattern).
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+
+ # Virtual environments (including the EcoScan one)
+ env/
+ venv/
+ ENV/
+ ecoscanenv/
+
+ # Build and packaging artifacts
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IDEs (VS Code, PyCharm, Vim swap and backup files)
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS (Mac Finder metadata, Windows thumbnail cache)
+ .DS_Store
+ Thumbs.db
+
+ # Model files (ignored locally; app.py downloads the weights from the Hub at startup)
+ *.pth
+ *.pt
+
+ # Gradio cache
+ gradio_cached_examples/
+ flagged/
+
+ # Testing
+ .pytest_cache/
+ .coverage
+ htmlcov/
app.py ADDED
@@ -0,0 +1,528 @@
+ """
+ EcoScan - AI-Powered Waste Sorting Classifier
+ Using a Gradio Interface for Deployment
+ """
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from PIL import Image
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import json
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+
+
+ # ============================================================================
+ # CONFIGURATION
+ # ============================================================================
+
+ class Config:
+     MODEL_PATH = "model/ecoscan_model.pth"
+     CLASS_NAMES_PATH = "model/class_names.json"
+     MODEL_NAME = "efficientnet_b3"
+     NUM_CLASSES = 6
+     IMAGE_SIZE = 300
+     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ config = Config()
+
+ # ============================================================================
+ # RECYCLING INFORMATION DATABASE
+ # ============================================================================
+
+ RECYCLING_INFO = {
+     "cardboard": {
+         "icon": "📦",
+         "tip": "Flatten boxes to save space. Remove any plastic tape or labels. Keep dry - wet cardboard contaminates recycling.",
+         "eco_score": 9,
+         "decompose_time": "2-3 months",
+         "facts": "Recycling 1 ton of cardboard saves 17 trees and 7,000 gallons of water!"
+     },
+     "glass": {
+         "icon": "🍾",
+         "tip": "Rinse glass containers to remove food residue. Remove lids and caps, as they are often made of different materials.",
+         "eco_score": 8,
+         "decompose_time": "1 million years",
+         "facts": "Recycling glass saves 30% of the energy required to make new glass from raw materials."
+     },
+     "metal": {
+         "icon": "🔩",
+         "tip": "Rinse aluminum cans and steel containers. Crush cans to save space. Metal recycling saves 95% of the energy!",
+         "eco_score": 9,
+         "decompose_time": "50-500 years",
+         "facts": "Recycling aluminum saves 95% of the energy needed to make new aluminum from raw materials."
+     },
+     "paper": {
+         "icon": "📄",
+         "tip": "Keep paper dry and clean. Remove staples and paper clips. Shred sensitive documents before recycling.",
+         "eco_score": 8,
+         "decompose_time": "2-6 weeks",
+         "facts": "Recycling 1 ton of paper saves 17 trees, 380 gallons of oil, and 7,000 gallons of water."
+     },
+     "plastic": {
+         "icon": "🧴",
+         "tip": "Rinse plastic containers to remove food residue. Check the recycling symbol and number to ensure it's accepted in your local program.",
+         "eco_score": 4,
+         "decompose_time": "450-1000 years",
+         "facts": "Only about 9% of all plastic waste ever produced has been recycled. Recycling plastic saves 88% of the energy compared to producing new plastic from raw materials."
+     },
+     "trash": {
+         "icon": "🗑️",
+         "tip": "This item is general waste or e-waste. Check for specialized recycling programs. Consider composting organic materials.",
+         "eco_score": 3,
+         "decompose_time": "Variable (decades to never)",
+         "facts": "E-waste contains valuable materials like gold and copper, but also toxic substances. Always use proper disposal."
+     }
+ }
+
+
+ # ============================================================================
+ # MODEL LOADING
+ # ============================================================================
+
+ def load_model():
+     """Load the trained model"""
+     print(f"Loading model on {config.DEVICE}...")
+
+     # Download from the Hub if not available locally
+     if not Path(config.MODEL_PATH).exists():
+         print("Downloading model from Hugging Face Hub...")
+         try:
+             hf_hub_download(
+                 repo_id="AyobamiMichael/ecoscan-model",
+                 filename="ecoscan_model.pth",
+                 local_dir="model",
+                 repo_type="model"
+             )
+         except Exception as e:
+             print(f"Error downloading model: {e}")
+             raise
+
+     # Check that the model file exists
+     if not Path(config.MODEL_PATH).exists():
+         raise FileNotFoundError(f"Model file not found: {config.MODEL_PATH}")
+
+     print(f"Loading complete model from: {config.MODEL_PATH}")
+     # Create the model architecture
+     if config.MODEL_NAME == "efficientnet_b3":
+         from torchvision.models import efficientnet_b3
+
+         # Build the bare architecture; the trained weights are loaded separately below
+         print("Building EfficientNet-B3 architecture...")
+         model = efficientnet_b3(weights=None)
+
+         # Input features of the stock classifier head
+         in_features = 1536
+         num_classes = config.NUM_CLASSES
+         print(f"EfficientNet-B3 classifier input features: {in_features}")
+
+         # Replace classifier
+         model.classifier = nn.Sequential(
+             nn.Dropout(p=0.3, inplace=True),
+             nn.Linear(in_features, num_classes)
+         )
+     elif config.MODEL_NAME == "resnet50":
+         from torchvision.models import resnet50
+
+         print("Building ResNet50 architecture...")
+         model = resnet50(weights=None)
+
+         # Input features of the stock fc head
+         in_features = 2048
+         num_classes = config.NUM_CLASSES
+         print(f"ResNet50 fc input features: {in_features}")
+
+         # Replace final layer
+         model.fc = nn.Linear(in_features, num_classes)
+
+     else:
+         raise ValueError(f"Unknown model: {config.MODEL_NAME}")
+
+     # Load trained weights
+     print(f"Loading weights from: {config.MODEL_PATH}")
+     state_dict = torch.load(config.MODEL_PATH, map_location=config.DEVICE)
+     try:
+         model.load_state_dict(state_dict, strict=True)
+         print("✅ All weights loaded successfully!")
+     except Exception as e:
+         print(f"⚠️ Warning: {e}")
+         print("Some weights may not match. Loading with strict=False...")
+         model.load_state_dict(state_dict, strict=False)
+         print("✅ Weights loaded (partial)")
+
+     model.to(config.DEVICE)
+     model.eval()
+
+     # Verify the model
+     print(f"✅ Model ready on {config.DEVICE}")
+     print(f"   Input features: {in_features}")
+     print(f"   Output classes: {config.NUM_CLASSES}")
+
+     return model
+
+
+ def load_class_names():
+     """Load class names from JSON file"""
+     with open(config.CLASS_NAMES_PATH, 'r') as f:
+         class_names = json.load(f)
+     return class_names
+
+
+ # ============================================================================
+ # IMAGE PREPROCESSING
+ # ============================================================================
+
+ def get_transforms():
+     """Get image preprocessing transforms"""
+     return transforms.Compose([
+         transforms.Resize(config.IMAGE_SIZE),
+         transforms.CenterCrop(config.IMAGE_SIZE),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225])
+     ])
+
+ # ============================================================================
+ # GRAD-CAM VISUALIZATION
+ # ============================================================================
+
+ class GradCAM:
+     """Gradient-weighted Class Activation Mapping"""
+
+     def __init__(self, model, target_layer):
+         self.model = model
+         self.target_layer = target_layer
+         self.gradients = None
+         self.activations = None
+
+         # Register hooks to capture the target layer's activations and gradients
+         target_layer.register_forward_hook(self.save_activations)
+         target_layer.register_backward_hook(self.save_gradients)
+
+     def save_activations(self, module, input, output):
+         self.activations = output.detach()
+
+     def save_gradients(self, module, grad_input, grad_output):
+         self.gradients = grad_output[0].detach()
+
+     def generate_cam(self, input_image, class_idx):
+         """Generate CAM for a specific class"""
+         try:
+             # Forward pass
+             output = self.model(input_image)
+
+             # Backward pass
+             self.model.zero_grad()
+             class_loss = output[0, class_idx]
+             class_loss.backward()
+
+             # Generate CAM
+             if self.gradients is None or self.activations is None:
+                 print("Warning: gradients or activations not captured")
+                 return np.ones((input_image.shape[2], input_image.shape[3]))
+
+             gradients = self.gradients[0]      # [C, H, W]
+             activations = self.activations[0]  # [C, H, W]
+
+             # Global average pooling on gradients
+             weights = torch.mean(gradients, dim=(1, 2))  # [C]
+
+             # Weighted combination (allocated on the same device as the activations)
+             cam = torch.zeros(activations.shape[1:], dtype=torch.float32,
+                               device=activations.device)
+             for i, w in enumerate(weights):
+                 cam += w * activations[i]
+
+             # ReLU
+             cam = torch.relu(cam)
+
+             # Normalize
+             cam_min = cam.min()
+             cam_max = cam.max()
+             if cam_max - cam_min > 0:
+                 cam = (cam - cam_min) / (cam_max - cam_min)
+             else:
+                 cam = torch.zeros_like(cam)
+
+             return cam.cpu().numpy()
+
+         except Exception as e:
+             print(f"Grad-CAM generation error: {e}")
+             return np.ones((input_image.shape[2], input_image.shape[3]))
+
+ def overlay_heatmap(image, heatmap, alpha=0.4):
+     """Overlay heatmap on original image"""
+
+     # Ensure image is a numpy array
+     if not isinstance(image, np.ndarray):
+         image = np.array(image)
+
+     # Ensure image is uint8
+     if image.dtype != np.uint8:
+         image = (image * 255).astype(np.uint8)
+
+     # Resize heatmap to match image
+     heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
+
+     # Apply colormap
+     heatmap = np.uint8(255 * heatmap)
+     heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+
+     # Convert BGR to RGB
+     heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+
+     # Overlay
+     overlay = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)
+
+     return overlay
+
+ # Global MODEL AND CLASS NAMES (loaded at startup)
+ model = None
+ class_names = None
+
+ # ============================================================================
+ # INFERENCE FUNCTION
+ # ============================================================================
+
+ def classify_image(image):
+     """Main classification function"""
+     global model, class_names
+
+     if image is None:
+         return None, None, "Please upload an image first!"
+
+     # Convert to PIL Image
+     if isinstance(image, np.ndarray):
+         pil_image = Image.fromarray(image)
+     else:
+         pil_image = image
+
+     # Preprocess
+     transform = get_transforms()
+     input_tensor = transform(pil_image).unsqueeze(0).to(config.DEVICE)
+
+     # Get predictions
+     with torch.no_grad():
+         outputs = model(input_tensor)
+         probabilities = torch.nn.functional.softmax(outputs, dim=1)
+         confidence, predicted = torch.max(probabilities, 1)
+
+     predicted_class = class_names[predicted.item()]
+     confidence_score = confidence.item()
+
+     # Generate Grad-CAM
+     try:
+         # Get target layer
+         if config.MODEL_NAME == "efficientnet_b3":
+             target_layer = model.features[-1]
+         elif config.MODEL_NAME == "resnet50":
+             target_layer = model.layer4[-1]
+
+         gradcam = GradCAM(model, target_layer)
+         cam = gradcam.generate_cam(input_tensor, predicted.item())
+
+         # Create overlay (overlay_heatmap is a module-level function, not a method)
+         original_img = np.array(pil_image.resize((config.IMAGE_SIZE, config.IMAGE_SIZE)))
+         heatmap_img = overlay_heatmap(original_img, cam)
+     except Exception as e:
+         print(f"Grad-CAM error: {e}")
+         heatmap_img = np.array(pil_image)
+
+     # Get recycling info
+     info = RECYCLING_INFO.get(predicted_class, RECYCLING_INFO["trash"])
+
+     # Format top-3 predictions
+     top3_probs, top3_indices = torch.topk(probabilities[0], 3)
+     predictions_dict = {}
+     for prob, idx in zip(top3_probs, top3_indices):
+         class_name = class_names[idx.item()]
+         predictions_dict[class_name] = float(prob.item())
+
+     # Create detailed output
+     output_text = f"""
+ ## {info['icon']} Classification Result
+
+ **Detected Material:** {predicted_class.upper()}
+ **Confidence:** {confidence_score*100:.1f}%
+
+ ---
+
+ ### ♻️ Recycling Instructions
+ {info['tip']}
+
+ ---
+
+ ### 📊 Environmental Impact
+ - **EcoScore:** {info['eco_score']}/10
+ - **Decomposition Time:** {info['decompose_time']}
+
+ ### 💡 Did You Know?
+ {info['facts']}
+ """
+
+     return predictions_dict, heatmap_img, output_text
+
+
+ # ============================================================================
+ # INITIALIZE MODEL & CLASS NAMES AT STARTUP
+ # ============================================================================
+
+ print("🚀 Initializing EcoScan...")
+ model = load_model()
+ class_names = load_class_names()
+ print(f"✅ Loaded {len(class_names)} classes: {class_names}")
+ print("🌱 EcoScan ready!")
+
+
+ # ============================================================================
+ # GRADIO INTERFACE
+ # ============================================================================
+
+ # Custom CSS
+ custom_css = """
+ #title {
+     text-align: center;
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     padding: 20px;
+     border-radius: 10px;
+     margin-bottom: 20px;
+ }
+ #output-box {
+     border: 2px solid #667eea;
+     border-radius: 10px;
+     padding: 15px;
+ }
+ .eco-high { color: #10b981; font-weight: bold; }
+ .eco-medium { color: #f59e0b; font-weight: bold; }
+ .eco-low { color: #ef4444; font-weight: bold; }
+ """
+
+ # Example images (only listed if the files actually exist)
+ examples = [
+     ["examples/plastic_bottle.jpg"] if Path("examples/plastic_bottle.jpg").exists() else None,
+     ["examples/cardboard_box.jpg"] if Path("examples/cardboard_box.jpg").exists() else None,
+     ["examples/glass_jar.jpg"] if Path("examples/glass_jar.jpg").exists() else None,
+ ]
+ examples = [ex for ex in examples if ex is not None]
+
+ # Create interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+
+     gr.Markdown(
+         """
+         <div id="title">
+             <h1>🌱 EcoScan - AI Waste Classifier</h1>
+             <p>Upload an image of waste material to get instant classification and recycling guidance</p>
+         </div>
+         """,
+         elem_id="title"
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(
+                 label="📸 Upload Waste Image",
+                 type="pil",
+                 height=400
+             )
+
+             classify_btn = gr.Button(
+                 "🔍 Classify Waste",
+                 variant="primary",
+                 size="lg"
+             )
+
+             gr.Markdown(
+                 """
+                 ### 📋 Instructions
+                 1. Upload a clear image of waste material
+                 2. Click "Classify Waste"
+                 3. View classification and recycling tips
+
+                 ### 🎯 Supported Categories
+                 Cardboard • Glass • Metal • Paper • Plastic • General Waste
+                 """
+             )
+
+         with gr.Column(scale=1):
+             with gr.Tab("📊 Results"):
+                 predictions = gr.Label(
+                     label="Classification Confidence",
+                     num_top_classes=3
+                 )
+                 recycling_info = gr.Markdown(
+                     label="Recycling Information",
+                     elem_id="output-box"
+                 )
+
+             with gr.Tab("🔥 AI Visualization"):
+                 heatmap = gr.Image(
+                     label="Attention Map (What the AI sees)",
+                     height=400
+                 )
+                 gr.Markdown(
+                     """
+                     **Grad-CAM Visualization**: Warmer colors (red/yellow) show regions
+                     the AI focused on for classification. Cooler colors (blue) indicate
+                     less important regions.
+                     """
+                 )
+
+     # Examples section
+     if examples:
+         gr.Examples(
+             examples=examples,
+             inputs=input_image,
+             label="📷 Try These Examples"
+         )
+
+     # Footer
+     gr.Markdown(
+         """
+         ---
+         <div style="text-align: center; color: #666;">
+             <p>Built with ❤️ for a sustainable future | Powered by EfficientNet-B3 & PyTorch</p>
+             <p>💡 <strong>Tip:</strong> This AI model was trained on 2,500+ waste images with 90%+ accuracy</p>
+         </div>
+         """
+     )
+
+     # Connect button
+     classify_btn.click(
+         fn=classify_image,
+         inputs=input_image,
+         outputs=[predictions, heatmap, recycling_info]
+     )
+
+ # ============================================================================
+ # LAUNCH
+ # ============================================================================
+
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True,
+         show_error=True,
+         debug=True
+     )
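Because `classify_image` is a plain function, the whole inference path can be smoke-tested without launching the UI. A minimal sketch, assuming it runs next to app.py with the weights already under model/ (the sample image path is hypothetical):

# Importing app executes its startup block, i.e. load_model() and load_class_names().
from PIL import Image

import app

img = Image.open("examples/plastic_bottle.jpg")  # hypothetical sample image
predictions, heatmap, info_text = app.classify_image(img)
print(predictions)   # top-3 {class: confidence} mapping fed to gr.Label
print(info_text)     # Markdown block with recycling tips and eco facts

Note that importing `app` also builds the Gradio Blocks; only `demo.launch(...)` is guarded by `__main__`, so no server starts during the test.
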
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ gradio>=5.49.1
+ opencv-python-headless>=4.8.0
+ Pillow>=10.0.0
+ numpy>=1.24.0
+ huggingface-hub>=0.16.0
+
test_inference.py ADDED
@@ -0,0 +1,259 @@
+ """
+ Quick inference test script to verify the model works before deployment.
+ Run this before deploying to catch any issues early.
+ """
+
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from PIL import Image
+ import json
+ import sys
+ from pathlib import Path
+
+ def test_model_loading():
+     """Test if the model loads correctly"""
+     print("=" * 60)
+     print("🧪 Testing Model Loading...")
+     print("=" * 60)
+
+     try:
+         # Check if the model file exists
+         model_path = "model/ecoscan_model.pth"
+         if not Path(model_path).exists():
+             print(f"❌ Model file not found: {model_path}")
+             print("   Please place your trained model in the model/ folder")
+             return False
+
+         print(f"✅ Found model file: {model_path}")
+
+         # Check class names
+         class_names_path = "model/class_names.json"
+         if not Path(class_names_path).exists():
+             print(f"❌ Class names file not found: {class_names_path}")
+             return False
+
+         with open(class_names_path, 'r') as f:
+             class_names = json.load(f)
+
+         print(f"✅ Found {len(class_names)} classes: {class_names}")
+
+         # Load model architecture
+         print("\n🏗️ Building model architecture...")
+         model = models.efficientnet_b3(weights=None)
+         in_features = model.classifier[1].in_features
+         model.classifier = nn.Sequential(
+             nn.Dropout(p=0.3, inplace=True),
+             nn.Linear(in_features, len(class_names))
+         )
+
+         # Load weights
+         print("📦 Loading weights...")
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.to(device)
+         model.eval()
+
+         print(f"✅ Model loaded successfully on {device}")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Error loading model: {e}")
+         import traceback
+         traceback.print_exc()
+         return False
+
+ def test_inference():
+     """Test inference on a dummy image"""
+     print("\n" + "=" * 60)
+     print("🔍 Testing Inference...")
+     print("=" * 60)
+
+     try:
+         # Load model
+         model_path = "model/ecoscan_model.pth"
+         class_names_path = "model/class_names.json"
+
+         with open(class_names_path, 'r') as f:
+             class_names = json.load(f)
+
+         model = models.efficientnet_b3(weights=None)
+         in_features = model.classifier[1].in_features
+         model.classifier = nn.Sequential(
+             nn.Dropout(p=0.3, inplace=True),
+             nn.Linear(in_features, len(class_names))
+         )
+
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         model.load_state_dict(torch.load(model_path, map_location=device))
+         model.to(device)
+         model.eval()
+
+         # Create dummy image
+         print("📸 Creating test image (300x300 RGB)...")
+         dummy_image = Image.new('RGB', (300, 300), color='blue')
+
+         # Preprocess
+         transform = transforms.Compose([
+             transforms.Resize((300, 300)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
+         ])
+
+         input_tensor = transform(dummy_image).unsqueeze(0).to(device)
+
+         # Run inference
+         print("🚀 Running inference...")
+         with torch.no_grad():
+             outputs = model(input_tensor)
+             probabilities = torch.nn.functional.softmax(outputs, dim=1)
+             confidence, predicted = torch.max(probabilities, 1)
+
+         predicted_class = class_names[predicted.item()]
+         confidence_score = confidence.item()
+
+         print("✅ Inference successful!")
+         print(f"   Predicted: {predicted_class}")
+         print(f"   Confidence: {confidence_score*100:.2f}%")
+
+         # Show top-3 predictions
+         print("\n📊 Top-3 Predictions:")
+         top3_probs, top3_indices = torch.topk(probabilities[0], min(3, len(class_names)))
+         for prob, idx in zip(top3_probs, top3_indices):
+             print(f"   {class_names[idx.item()]}: {prob.item()*100:.2f}%")
+
+         return True
+
+     except Exception as e:
+         print(f"❌ Error during inference: {e}")
+         import traceback
+         traceback.print_exc()
+         return False
+
+ def test_dependencies():
+     """Test if all required packages are installed"""
+     print("\n" + "=" * 60)
+     print("📦 Testing Dependencies...")
+     print("=" * 60)
+
+     required_packages = {
+         'torch': 'PyTorch',
+         'torchvision': 'TorchVision',
+         'PIL': 'Pillow',
+         'gradio': 'Gradio',
+         'cv2': 'OpenCV (cv2)',
+         'numpy': 'NumPy'
+     }
+
+     all_installed = True
+
+     for package, name in required_packages.items():
+         try:
+             __import__(package)
+             print(f"✅ {name}")
+         except ImportError:
+             print(f"❌ {name} - NOT INSTALLED")
+             all_installed = False
+
+     return all_installed
+
+ def test_file_structure():
+     """Test if the project structure is correct"""
+     print("\n" + "=" * 60)
+     print("📂 Testing File Structure...")
+     print("=" * 60)
+
+     required_files = [
+         "app.py",
+         "requirements.txt",
+         "README.md",
+         "model/ecoscan_model.pth",
+         "model/class_names.json"
+     ]
+
+     optional_files = [
+         "examples/plastic_bottle.jpg",
+         "examples/cardboard_box.jpg",
+         "examples/glass_jar.jpg"
+     ]
+
+     all_present = True
+
+     print("\n🔍 Required files:")
+     for file_path in required_files:
+         if Path(file_path).exists():
+             size = Path(file_path).stat().st_size / (1024 * 1024)  # MB
+             print(f"✅ {file_path} ({size:.2f} MB)")
+         else:
+             print(f"❌ {file_path} - MISSING")
+             all_present = False
+
+     print("\n🎨 Optional files:")
+     for file_path in optional_files:
+         if Path(file_path).exists():
+             print(f"✅ {file_path}")
+         else:
+             print(f"⚠️ {file_path} - not found (optional)")
+
+     return all_present
+
+ def main():
+     """Run all tests"""
+     print("\n")
+     print("╔" + "=" * 58 + "╗")
+     print("║" + " " * 58 + "║")
+     print("║" + " 🌱 EcoScan - Pre-Deployment Testing Suite ".center(58) + "║")
+     print("║" + " " * 58 + "║")
+     print("╚" + "=" * 58 + "╝")
+     print("\n")
+
+     tests = [
+         ("File Structure", test_file_structure),
+         ("Dependencies", test_dependencies),
+         ("Model Loading", test_model_loading),
+         ("Inference", test_inference)
+     ]
+
+     results = {}
+
+     for test_name, test_func in tests:
+         try:
+             results[test_name] = test_func()
+         except Exception as e:
+             print(f"\n❌ Test '{test_name}' crashed: {e}")
+             results[test_name] = False
+
+     # Summary
+     print("\n" + "=" * 60)
+     print("📋 TEST SUMMARY")
+     print("=" * 60)
+
+     for test_name, passed in results.items():
+         status = "✅ PASSED" if passed else "❌ FAILED"
+         print(f"{test_name:.<40} {status}")
+
+     all_passed = all(results.values())
+
+     print("\n" + "=" * 60)
+     if all_passed:
+         print("🎉 ALL TESTS PASSED!")
+         print("✅ Your app is ready for deployment!")
+         print("\nNext steps:")
+         print("  1. Test locally: python app.py")
+         print("  2. Deploy to Hugging Face Spaces")
+         print("  3. Share with the world! 🌍")
+     else:
+         print("⚠️ SOME TESTS FAILED")
+         print("Please fix the issues above before deploying.")
+         print("\nCommon fixes:")
+         print("  - Install missing packages: pip install -r requirements.txt")
+         print("  - Download model from Kaggle to model/ folder")
+         print("  - Verify file paths match your structure")
+     print("=" * 60 + "\n")
+
+     return 0 if all_passed else 1
+
+ if __name__ == "__main__":
+     sys.exit(main())
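The suite already returns a nonzero exit code on failure, so `python test_inference.py` can gate a deploy step as-is. If a pytest-based CI is preferred, a thin wrapper is enough — a sketch, not part of this commit:

# Hypothetical pytest adapter: each check returns a bool, so a bare assert
# turns it into an ordinary test case.
import test_inference

def test_file_structure():
    assert test_inference.test_file_structure()

def test_dependencies():
    assert test_inference.test_dependencies()

def test_model_loading():
    assert test_inference.test_model_loading()
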
verify_model.py ADDED
@@ -0,0 +1,33 @@
+ import torch
+ from pathlib import Path
+
+ # Check model file
+ model_path = "model/ecoscan_model.pth"
+ print(f"Model exists: {Path(model_path).exists()}")
+
+ # Load and inspect
+ if Path(model_path).exists():
+     checkpoint = torch.load(model_path, map_location='cpu')
+
+     print("\nModel info:")
+     print(f"Type: {type(checkpoint)}")
+
+     if isinstance(checkpoint, dict):
+         print(f"Keys: {checkpoint.keys()}")
+         if 'state_dict' in checkpoint:
+             state_dict = checkpoint['state_dict']
+         else:
+             state_dict = checkpoint
+     else:
+         state_dict = checkpoint
+
+     # Check shapes of the first few layers
+     print("\nLayer shapes:")
+     for key, value in list(state_dict.items())[:5]:
+         print(f"  {key}: {value.shape}")
+
+     # Check classifier
+     if 'classifier.1.weight' in state_dict:
+         weight = state_dict['classifier.1.weight']
+         print(f"\nClassifier output: {weight.shape[0]} classes")
+         print(f"Classifier input: {weight.shape[1]} features")
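verify_model.py only recognizes the EfficientNet-style head (`classifier.1.weight`). Since app.py also supports a ResNet50 variant whose final layer is named `fc`, the same shape check could be extended — a hedged addition, indented to sit inside the existing `if Path(model_path).exists():` block:

    # Hypothetical extension: ResNet-style checkpoints name the final layer "fc"
    if 'fc.weight' in state_dict:
        weight = state_dict['fc.weight']
        print(f"\nClassifier output: {weight.shape[0]} classes")
        print(f"Classifier input: {weight.shape[1]} features")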