Abs6187 commited on
Commit
7a97282
·
verified ·
1 Parent(s): 1752e5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -119
app.py CHANGED
@@ -1,154 +1,178 @@
1
  import gradio as gr
2
  import torch
 
 
3
  import numpy as np
 
 
4
  import cv2
5
- from ultralytics import YOLO
6
- from PIL import Image
7
- import requests
8
- import json
9
 
10
- # Download sample images
11
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-1.jpg?raw=true', 'sample_1.jpg')
12
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-2.jpg?raw=true', 'sample_2.jpg')
13
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-3.jpg?raw=true', 'sample_3.jpg')
14
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-4.jpg?raw=true', 'sample_4.jpg')
15
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-5.jpg?raw=true', 'sample_5.jpg')
16
 
 
17
  model = YOLO("best.pt")
 
18
 
19
- GEMINI_API_KEY = "AIzaSyCBs4TumAonKI0AodIzbl4b8Vmu9eM_r9I"
20
- GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
21
-
22
- def get_gemini_response(prompt: str) -> str:
23
- headers = {
24
- 'Content-Type': 'application/json',
25
- }
26
-
27
- data = {
28
- "contents": [{"parts": [{"text": prompt}]}]
29
- }
30
-
31
- try:
32
- response = requests.post(
33
- f"{GEMINI_API_URL}?key={GEMINI_API_KEY}",
34
- headers=headers,
35
- json=data
36
- )
37
- response.raise_for_status()
38
- result = response.json()
39
- return result['candidates'][0]['content']['parts'][0]['text']
40
- except Exception as e:
41
- print(f"Error calling Gemini API: {str(e)}")
42
- return "Unable to get analysis at the moment."
43
-
44
- def get_safety_analysis(stats: dict) -> str:
45
- prompt = f"""
46
- You are a traffic safety analyst. Analyze the following statistics and provide a brief safety report:
47
-
48
- Detection Results:
49
- - Total Detections: {stats.get('total_detections', 0)}
50
- - Riders with Helmet: {stats.get('with_helmet', 0)}
51
- - Riders without Helmet: {stats.get('without_helmet', 0)}
52
- - Helmet Compliance Rate: {stats.get('helmet_compliance', 0)}%
53
- - License Plates Detected: {stats.get('license_plates', 0)}
54
-
55
- Provide a 3-4 sentence safety analysis focusing on helmet compliance and potential safety concerns.
56
- """
57
- return get_gemini_response(prompt)
58
-
59
- def yoloV8_func(image=None, image_size=640, conf_threshold=0.4, iou_threshold=0.5):
60
- print(f"Received image_size: {image_size}")
61
-
62
  if image_size is None:
63
  image_size = 640
64
 
 
65
  if not isinstance(image_size, int):
66
  image_size = int(image_size)
67
 
 
68
  imgsz = [image_size, image_size]
69
 
70
- results = model.predict(
71
- source=image,
72
- conf=conf_threshold,
73
- iou=iou_threshold,
74
- imgsz=imgsz,
75
- verbose=False
76
- )
 
 
77
 
78
- boxes = results[0].boxes.xyxy.cpu().numpy()
79
- class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
 
80
 
81
- total_riders = int(sum((class_ids == 0) | (class_ids == 1)))
82
- helmet_compliance = 0 if total_riders == 0 else int(sum(class_ids == 0) / total_riders * 100)
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- stats = {
85
- 'total_detections': len(boxes),
86
- 'with_helmet': int(sum(class_ids == 0)),
87
- 'without_helmet': int(sum(class_ids == 1)),
88
- 'license_plates': int(sum(class_ids == 2)),
89
- 'helmet_compliance': helmet_compliance,
90
- 'total_riders': total_riders
91
- }
92
 
93
- safety_analysis = get_safety_analysis(stats)
94
- print("\nSafety Analysis:", safety_analysis)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- annotated_image = results[0].plot()
 
97
 
98
- return annotated_image
99
 
100
- with gr.Blocks() as demo:
101
- with gr.Row():
102
- image_input = gr.Image(type="filepath", label="Input Image")
103
- output_image = gr.Image(type="pil", label="Output Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  with gr.Row():
106
- image_size = gr.Slider(
107
- minimum=320,
108
- maximum=1280,
109
- value=640,
110
- step=32,
111
- label="Image Size",
112
- interactive=True
113
- )
114
- conf_threshold = gr.Slider(
115
- minimum=0.1,
116
- maximum=1.0,
117
- value=0.4,
118
- step=0.05,
119
- label="Confidence Threshold",
120
- interactive=True
121
- )
122
- iou_threshold = gr.Slider(
123
- minimum=0.1,
124
- maximum=1.0,
125
- value=0.5,
126
- step=0.05,
127
- label="IOU Threshold",
128
- interactive=True
129
- )
130
 
131
- process_btn = gr.Button("Process Image")
132
- process_btn.click(
 
 
 
 
133
  fn=yoloV8_func,
134
- inputs=[image_input, image_size, conf_threshold, iou_threshold],
135
- outputs=output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
-
138
- outputs = gr.Image(type="pil", label="Output Image")
139
-
140
- title = "YOLOv11 Motorcyclist Helmet Detection"
141
- description = """
142
- This application uses YOLOv11 to detect Motorcyclists with and without Helmets in images.
143
- Upload an image, adjust the confidence and IOU thresholds, and view the detection results.
144
- You can customize the model's performance to fit your needs.
145
- """
146
- article = """
147
- <h2>How It Works:</h2>
148
- <p>This model detects Motorcyclists with and without Helmets in images and highlights them with bounding boxes.
149
- Adjust the confidence threshold to control detection accuracy and the IOU threshold for overlap sensitivity.</p>
150
- <p>Upload your images and try it out!</p>
151
- """
152
 
153
  if __name__ == "__main__":
154
- demo.launch(debug=True)
 
1
  import gradio as gr
2
  import torch
3
+ from ultralytics import YOLO
4
+ from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
+ import pandas as pd
7
+ import os
8
  import cv2
9
+ import time
 
 
 
10
 
11
+ # Download sample images (optional)
12
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-1.jpg?raw=true', 'sample_1.jpg')
13
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-2.jpg?raw=true', 'sample_2.jpg')
14
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-3.jpg?raw=true', 'sample_3.jpg')
15
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-4.jpg?raw=true', 'sample_4.jpg')
16
  torch.hub.download_url_to_file('https://github.com/Janno1402/Helmet-License-Plate-Detection/blob/main/Sample-Image-5.jpg?raw=true', 'sample_5.jpg')
17
 
18
+ # Load model (cached for performance)
19
  model = YOLO("best.pt")
20
+ class_names = {0: 'With Helmet', 1: 'Without Helmet', 2: 'License Plate'}
21
 
22
+ def yoloV8_func(
23
+ image=None,
24
+ image_size=640,
25
+ conf_threshold=0.4,
26
+ iou_threshold=0.5,
27
+ show_stats=True,
28
+ show_confidence=True
29
+ ):
30
+ # Handle NoneType for image_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if image_size is None:
32
  image_size = 640
33
 
34
+ # Ensure image_size is an integer
35
  if not isinstance(image_size, int):
36
  image_size = int(image_size)
37
 
38
+ # Construct imgsz as a list of two integers [width, height]
39
  imgsz = [image_size, image_size]
40
 
41
+ # Make predictions
42
+ results = model.predict(image, conf=conf_threshold, iou=iou_threshold, imgsz=imgsz)
43
+
44
+ # Get the output image with bounding boxes
45
+ annotated_image = results[0].plot() # This returns a PIL image
46
+
47
+ # Convert to PIL if it's a numpy array
48
+ if isinstance(annotated_image, np.ndarray):
49
+ annotated_image = Image.fromarray(cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB))
50
 
51
+ # Extract detection information
52
+ boxes = results[0].boxes
53
+ detections = []
54
 
55
+ if boxes is not None and len(boxes) > 0:
56
+ for i, (box, cls, conf) in enumerate(zip(boxes.xyxy, boxes.cls, boxes.conf)):
57
+ x1, y1, x2, y2 = box.tolist()
58
+ class_id = int(cls)
59
+ confidence = float(conf)
60
+ label = class_names.get(class_id, f"Class {class_id}")
61
+
62
+ detections.append({
63
+ "Object": label,
64
+ "Confidence": f"{confidence:.2f}",
65
+ "Position": f"({int(x1)}, {int(y1)})",
66
+ "Dimensions": f"{int(x2-x1)}x{int(y2-y1)}"
67
+ })
68
 
69
+ # Create stats text
70
+ stats_text = ""
71
+ if show_stats and detections:
72
+ df = pd.DataFrame(detections)
73
+ counts = df['Object'].value_counts().to_dict()
74
+ stats_text = "Detection Summary:\n"
75
+ for obj, count in counts.items():
76
+ stats_text += f"- {obj}: {count}\n"
77
 
78
+ # Add stats to image if requested
79
+ if show_stats and stats_text:
80
+ draw = ImageDraw.Draw(annotated_image)
81
+ try:
82
+ font = ImageFont.truetype("arial.ttf", 20)
83
+ except:
84
+ font = ImageFont.load_default()
85
+
86
+ # Add semi-transparent background for text
87
+ text_bbox = draw.textbbox((0, 0), stats_text, font=font)
88
+ text_width = text_bbox[2] - text_bbox[0]
89
+ text_height = text_bbox[3] - text_bbox[1]
90
+ draw.rectangle([10, 10, 20 + text_width, 20 + text_height], fill=(0, 0, 0, 128))
91
+
92
+ # Add text
93
+ draw.text((15, 15), stats_text, font=font, fill=(255, 255, 255))
94
 
95
+ # Create a detection table for display
96
+ detection_table = pd.DataFrame(detections) if detections else pd.DataFrame(columns=["Object", "Confidence", "Position", "Dimensions"])
97
 
98
+ return annotated_image, detection_table, stats_text
99
 
100
+ # Define custom CSS for styling
101
+ custom_css = """
102
+ #title { text-align: center; }
103
+ #description { text-align: center; }
104
+ .footer {
105
+ text-align: center;
106
+ margin-top: 20px;
107
+ color: #666;
108
+ }
109
+ .important { font-weight: bold; color: red; }
110
+ """
111
+
112
+ # Set up Gradio interface with Blocks for more control
113
+ with gr.Blocks(css=custom_css, title="YOLOv11 Motorcyclist Helmet Detection") as demo:
114
+ gr.HTML("<h1 id='title'>YOLOv11 Motorcyclist Helmet Detection</h1>")
115
+ gr.HTML("""
116
+ <div id='description'>
117
+ <p>This application uses YOLOv11 to detect Motorcyclists with and without Helmets in images.</p>
118
+ <p>Upload an image, adjust the parameters, and view the detection results with detailed statistics.</p>
119
+ </div>
120
+ """)
121
 
122
  with gr.Row():
123
+ with gr.Column(scale=1):
124
+ gr.Markdown("### Input Parameters")
125
+ input_image = gr.Image(type="filepath", label="Input Image", sources=["upload", "webcam"])
126
+ with gr.Row():
127
+ image_size = gr.Slider(minimum=320, maximum=1280, value=640, step=32, label="Image Size")
128
+ conf_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
129
+ with gr.Row():
130
+ iou_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="IOU Threshold")
131
+ show_stats = gr.Checkbox(value=True, label="Show Statistics on Image")
132
+
133
+ submit_btn = gr.Button("Detect Objects", variant="primary")
134
+ clear_btn = gr.Button("Clear")
135
+
136
+ with gr.Column(scale=2):
137
+ gr.Markdown("### Output Results")
138
+ output_image = gr.Image(type="pil", label="Output Image")
139
+ output_table = gr.Dataframe(
140
+ headers=["Object", "Confidence", "Position", "Dimensions"],
141
+ label="Detection Details",
142
+ interactive=False
143
+ )
144
+ output_stats = gr.Textbox(label="Detection Summary", interactive=False)
 
 
145
 
146
+ # Examples
147
+ gr.Markdown("### Example Images")
148
+ gr.Examples(
149
+ examples=[["sample_1.jpg"], ["sample_2.jpg"], ["sample_3.jpg"], ["sample_4.jpg"], ["sample_5.jpg"]],
150
+ inputs=input_image,
151
+ outputs=[output_image, output_table, output_stats],
152
  fn=yoloV8_func,
153
+ cache_examples=True,
154
+ )
155
+
156
+ # Footer
157
+ gr.HTML("""
158
+ <div class='footer'>
159
+ <p>Built with Gradio and Ultralytics YOLO</p>
160
+ <p>Note: This is a demonstration application. Detection accuracy may vary based on image quality and conditions.</p>
161
+ </div>
162
+ """)
163
+
164
+ # Button actions
165
+ submit_btn.click(
166
+ fn=yoloV8_func,
167
+ inputs=[input_image, image_size, conf_threshold, iou_threshold, show_stats],
168
+ outputs=[output_image, output_table, output_stats]
169
+ )
170
+
171
+ clear_btn.click(
172
+ fn=lambda: [None, None, None],
173
+ inputs=[],
174
+ outputs=[input_image, output_image, output_table, output_stats]
175
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  if __name__ == "__main__":
178
+ demo.launch(debug=True, share=True)