Simon9 commited on
Commit
6a9d119
ยท
verified ยท
1 Parent(s): 3bb84e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +533 -314
app.py CHANGED
@@ -45,14 +45,15 @@ print(f"๐Ÿ–ฅ๏ธ Using device: {DEVICE}")
45
  # ==============================================
46
  CLIENT = InferenceHTTPClient(
47
  api_url="https://detect.roboflow.com",
48
- api_key=ROBOFLOW_API_KEY
49
  )
50
 
51
  PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
52
  FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
53
 
 
54
  def infer_with_confidence(model_id: str, frame: np.ndarray, confidence_threshold: float = 0.3):
55
- """Run inference and filter by confidence threshold"""
56
  result = CLIENT.infer(frame, model_id=model_id)
57
  detections = sv.Detections.from_inference(result)
58
  # Filter by confidence
@@ -60,6 +61,7 @@ def infer_with_confidence(model_id: str, frame: np.ndarray, confidence_threshold
60
  detections = detections[detections.confidence > confidence_threshold]
61
  return result, detections
62
 
 
63
  # ==============================================
64
  # SIGLIP MODEL (Embeddings)
65
  # ==============================================
@@ -74,7 +76,6 @@ CONFIG = SoccerPitchConfiguration()
74
 
75
  # ==============================================
76
  # TABLE HEADERS FOR GRADIO DATAFRAMES
77
- # (IMPORTANT: col_count MUST equal len(headers))
78
  # ==============================================
79
  PLAYER_STATS_HEADERS = [
80
  "Player ID",
@@ -107,9 +108,9 @@ EVENT_HEADERS = [
107
  # ==============================================
108
  def replace_outliers_based_on_distance(
109
  positions: List[np.ndarray],
110
- distance_threshold: float
111
  ) -> List[np.ndarray]:
112
- """Remove outlier positions based on distance threshold"""
113
  last_valid_position: Union[np.ndarray, None] = None
114
  cleaned_positions: List[np.ndarray] = []
115
 
@@ -130,6 +131,7 @@ def replace_outliers_based_on_distance(
130
 
131
  return cleaned_positions
132
 
 
133
  # ==============================================
134
  # PITCH DISTANCE (UNITS FIX: meters)
135
  # ==============================================
@@ -149,37 +151,40 @@ def pitch_distance_m(p1: np.ndarray, p2: np.ndarray) -> float:
149
  else:
150
  return d
151
 
 
152
  # ==============================================
153
  # PLAYER PERFORMANCE TRACKING
154
  # ==============================================
155
  class PlayerPerformanceTracker:
156
- """Track individual player performance metrics and generate heatmaps"""
157
-
158
  def __init__(self, pitch_config, fps: float = 30.0):
159
  self.config = pitch_config
160
  self.fps = fps
161
  self.player_positions = defaultdict(list)
162
- self.player_velocities = defaultdict(list) # km/h
163
- self.player_distances = defaultdict(float) # meters
164
  self.player_team = {}
165
- self.player_stats = defaultdict(lambda: {
166
- 'frames_visible': 0,
167
- 'avg_velocity': 0.0, # km/h
168
- 'max_velocity': 0.0, # km/h
169
- 'time_in_attacking_third': 0,
170
- 'time_in_defensive_third': 0,
171
- 'time_in_middle_third': 0
172
- })
173
-
 
 
174
  def update(self, tracker_id: int, position: np.ndarray, team_id: int, frame: int):
175
- """Update player position and calculate metrics"""
176
  if len(position) != 2:
177
  return
178
-
179
  self.player_team[tracker_id] = team_id
180
  self.player_positions[tracker_id].append((position[0], position[1], frame))
181
- self.player_stats[tracker_id]['frames_visible'] += 1
182
-
183
  if len(self.player_positions[tracker_id]) > 1:
184
  prev_pos = np.array(self.player_positions[tracker_id][-2][:2], dtype=float)
185
  curr_pos = np.array(position, dtype=float)
@@ -187,254 +192,309 @@ class PlayerPerformanceTracker:
187
  # distance in meters between frames
188
  distance_m = pitch_distance_m(prev_pos, curr_pos)
189
  self.player_distances[tracker_id] += distance_m
190
-
191
  # speed in km/h
192
  speed_mps = distance_m * self.fps
193
  speed_kmh = speed_mps * 3.6
194
  self.player_velocities[tracker_id].append(speed_kmh)
195
-
196
- if speed_kmh > self.player_stats[tracker_id]['max_velocity']:
197
- self.player_stats[tracker_id]['max_velocity'] = speed_kmh
198
-
199
  pitch_length = self.config.length
200
  if position[0] < pitch_length / 3:
201
- self.player_stats[tracker_id]['time_in_defensive_third'] += 1
202
  elif position[0] < 2 * pitch_length / 3:
203
- self.player_stats[tracker_id]['time_in_middle_third'] += 1
204
  else:
205
- self.player_stats[tracker_id]['time_in_attacking_third'] += 1
206
-
207
  def get_player_stats(self, tracker_id: int) -> dict:
208
- """Get comprehensive stats for a player"""
209
  stats = self.player_stats[tracker_id].copy()
210
-
211
  if len(self.player_velocities[tracker_id]) > 0:
212
- stats['avg_velocity'] = float(np.mean(self.player_velocities[tracker_id]))
213
-
214
- stats['total_distance_meters'] = float(self.player_distances[tracker_id])
215
- stats['team_id'] = int(self.player_team.get(tracker_id, -1))
216
-
217
  return stats
218
-
219
  def generate_heatmap(self, tracker_id: int, resolution: int = 100) -> np.ndarray:
220
- """Generate heatmap for a specific player"""
221
  if tracker_id not in self.player_positions or len(self.player_positions[tracker_id]) == 0:
222
  return np.zeros((resolution, resolution))
223
-
224
  positions = np.array([(x, y) for x, y, _ in self.player_positions[tracker_id]])
225
-
226
  pitch_length = self.config.length
227
  pitch_width = self.config.width
228
-
229
  heatmap, xedges, yedges = np.histogram2d(
230
- positions[:, 0], positions[:, 1],
 
231
  bins=[resolution, resolution],
232
- range=[[0, pitch_length], [0, pitch_width]]
233
  )
234
-
235
  heatmap = gaussian_filter(heatmap, sigma=3)
236
-
237
  return heatmap.T
238
-
239
  def get_all_players_by_team(self) -> Dict[int, List[int]]:
240
- """Get all player IDs grouped by team"""
241
  teams = defaultdict(list)
242
  for tracker_id, team_id in self.player_team.items():
243
  teams[team_id].append(tracker_id)
244
  return teams
245
 
 
246
  # ==============================================
247
  # TRACKING MANAGER
248
  # ==============================================
249
  class PlayerTrackingManager:
250
- """Manages persistent player tracking with team assignment stability"""
251
-
252
  def __init__(self, max_history=10):
253
  self.tracker_team_history: Dict[int, List[int]] = defaultdict(list)
254
  self.max_history = max_history
255
  self.active_trackers = set()
256
-
257
  def update_team_assignment(self, tracker_id: int, team_id: int):
258
- """Store team assignment history for each tracker"""
259
  self.tracker_team_history[tracker_id].append(team_id)
260
  if len(self.tracker_team_history[tracker_id]) > self.max_history:
261
  self.tracker_team_history[tracker_id].pop(0)
262
  self.active_trackers.add(tracker_id)
263
-
264
  def get_stable_team_id(self, tracker_id: int, current_team_id: int) -> int:
265
- """Get stable team ID using majority voting from history"""
266
  if tracker_id not in self.tracker_team_history or len(self.tracker_team_history[tracker_id]) < 3:
267
  return current_team_id
268
-
269
  history = self.tracker_team_history[tracker_id]
270
  team_counts = np.bincount(history)
271
  stable_team = int(np.argmax(team_counts))
272
  return stable_team
273
-
274
  def get_player_count_by_team(self) -> Dict[int, int]:
275
- """Get current count of players per team"""
276
  team_counts = defaultdict(int)
277
  for tracker_id in self.active_trackers:
278
  if tracker_id in self.tracker_team_history and len(self.tracker_team_history[tracker_id]) > 0:
279
- stable_team = self.get_stable_team_id(tracker_id, self.tracker_team_history[tracker_id][-1])
 
 
 
280
  team_counts[stable_team] += 1
281
  return team_counts
282
-
283
  def reset_frame(self):
284
- """Reset active trackers for new frame"""
285
  self.active_trackers = set()
286
 
 
287
  # ==============================================
288
  # VISUALIZATION FUNCTIONS
289
  # ==============================================
290
- def create_player_heatmap_visualization(performance_tracker: PlayerPerformanceTracker,
291
- tracker_id: int) -> np.ndarray:
292
- """Create a single player heatmap overlay on pitch"""
 
 
293
  pitch = draw_pitch(CONFIG)
294
  heatmap = performance_tracker.generate_heatmap(tracker_id, resolution=150)
295
-
296
  if heatmap.max() > 0:
297
  heatmap = heatmap / heatmap.max()
298
-
299
  padding = 50
300
-
301
  pitch_height, pitch_width = pitch.shape[:2]
302
- heatmap_resized = cv2.resize(heatmap, (pitch_width - 2*padding, pitch_height - 2*padding))
303
-
304
  heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
305
-
306
  overlay = pitch.copy()
307
- overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
308
-
309
  result = cv2.addWeighted(pitch, 0.6, overlay, 0.4, 0)
310
-
311
  stats = performance_tracker.get_player_stats(tracker_id)
312
- team_color = "Blue" if stats['team_id'] == 0 else "Pink"
313
-
314
  text_lines = [
315
  f"Player #{tracker_id} ({team_color} Team)",
316
  f"Distance: {stats['total_distance_meters']:.1f} m",
317
  f"Avg Speed: {stats['avg_velocity']:.2f} km/h",
318
  f"Max Speed: {stats['max_velocity']:.2f} km/h",
319
- f"Frames: {stats['frames_visible']}"
320
  ]
321
-
322
  y_offset = 30
323
  for line in text_lines:
324
- cv2.putText(result, line, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX,
325
- 0.6, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
326
  y_offset += 25
327
-
328
  return result
329
 
330
 
331
  def create_team_comparison_plot(performance_tracker: PlayerPerformanceTracker) -> go.Figure:
332
- """Create interactive performance comparison plots"""
333
  teams = performance_tracker.get_all_players_by_team()
334
-
335
  fig = make_subplots(
336
- rows=2, cols=2,
337
- subplot_titles=('Distance Covered', 'Average Speed', 'Max Speed', 'Activity by Zone'),
338
- specs=[[{'type': 'bar'}, {'type': 'bar'}],
339
- [{'type': 'bar'}, {'type': 'bar'}]]
 
 
 
 
 
340
  )
341
-
342
- colors = {0: '#00BFFF', 1: '#FF1493'}
343
- team_names = {0: 'Team 0 (Blue)', 1: 'Team 1 (Pink)'}
344
-
345
  for team_id, player_ids in teams.items():
346
  if team_id not in [0, 1]:
347
  continue
348
-
349
  distances = []
350
  avg_speeds = []
351
  max_speeds = []
352
  attacking_time = []
353
-
354
  for pid in player_ids:
355
  stats = performance_tracker.get_player_stats(pid)
356
- distances.append(stats['total_distance_meters'])
357
- avg_speeds.append(stats['avg_velocity']) # km/h
358
- max_speeds.append(stats['max_velocity']) # km/h
359
- attacking_time.append(stats['time_in_attacking_third'])
360
-
361
  player_labels = [f"#{pid}" for pid in player_ids]
362
-
363
  fig.add_trace(
364
- go.Bar(x=player_labels, y=distances, name=team_names[team_id],
365
- marker_color=colors[team_id], showlegend=True),
366
- row=1, col=1
 
 
 
 
 
 
367
  )
368
-
369
  fig.add_trace(
370
- go.Bar(x=player_labels, y=avg_speeds, name=team_names[team_id],
371
- marker_color=colors[team_id], showlegend=False),
372
- row=1, col=2
 
 
 
 
 
 
373
  )
374
-
375
  fig.add_trace(
376
- go.Bar(x=player_labels, y=max_speeds, name=team_names[team_id],
377
- marker_color=colors[team_id], showlegend=False),
378
- row=2, col=1
 
 
 
 
 
 
379
  )
380
-
381
  fig.add_trace(
382
- go.Bar(x=player_labels, y=attacking_time, name=team_names[team_id],
383
- marker_color=colors[team_id], showlegend=False),
384
- row=2, col=2
 
 
 
 
 
 
385
  )
386
-
387
  fig.update_xaxes(title_text="Players", row=1, col=1)
388
  fig.update_xaxes(title_text="Players", row=1, col=2)
389
  fig.update_xaxes(title_text="Players", row=2, col=1)
390
  fig.update_xaxes(title_text="Players", row=2, col=2)
391
-
392
  fig.update_yaxes(title_text="Distance (m)", row=1, col=1)
393
  fig.update_yaxes(title_text="Speed (km/h)", row=1, col=2)
394
  fig.update_yaxes(title_text="Speed (km/h)", row=2, col=1)
395
  fig.update_yaxes(title_text="Frames in Zone", row=2, col=2)
396
-
397
- fig.update_layout(height=800, title_text="Team Performance Comparison", barmode='group')
398
-
399
  return fig
400
 
401
 
402
  def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> np.ndarray:
403
- """Create side-by-side team heatmaps"""
404
  teams = performance_tracker.get_all_players_by_team()
405
-
406
  team_heatmaps = []
407
  for team_id in [0, 1]:
408
  if team_id not in teams:
409
  continue
410
-
411
  combined_heatmap = np.zeros((150, 150))
412
  for pid in teams[team_id]:
413
  player_heatmap = performance_tracker.generate_heatmap(pid, resolution=150)
414
  combined_heatmap += player_heatmap
415
-
416
  if combined_heatmap.max() > 0:
417
  combined_heatmap = combined_heatmap / combined_heatmap.max()
418
-
419
  pitch = draw_pitch(CONFIG)
420
  padding = 50
421
  pitch_height, pitch_width = pitch.shape[:2]
422
- heatmap_resized = cv2.resize(combined_heatmap,
423
- (pitch_width - 2*padding, pitch_height - 2*padding))
424
-
 
 
425
  colormap = cv2.COLORMAP_JET if team_id == 0 else cv2.COLORMAP_HOT
426
  heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), colormap)
427
-
428
  overlay = pitch.copy()
429
- overlay[padding:pitch_height-padding, padding:pitch_width-padding] = heatmap_colored
430
  result = cv2.addWeighted(pitch, 0.5, overlay, 0.5, 0)
431
-
432
  team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
433
- cv2.putText(result, team_name, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
434
- 1, (255, 255, 255), 2, cv2.LINE_AA)
435
-
 
 
 
 
 
 
 
 
436
  team_heatmaps.append(result)
437
-
438
  if len(team_heatmaps) == 2:
439
  return np.hstack(team_heatmaps)
440
  elif len(team_heatmaps) == 1:
@@ -442,28 +502,36 @@ def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> n
442
  else:
443
  return draw_pitch(CONFIG)
444
 
 
445
  # ==============================================
446
  # HELPER FUNCTIONS
447
  # ==============================================
448
  def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
449
- """Assign goalkeepers to the nearest team centroid"""
450
  if len(goalkeepers) == 0 or len(players) == 0:
451
  return np.array([])
452
  goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
453
  players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
454
  team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
455
  team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)
456
- return np.array([
457
- 0 if np.linalg.norm(gk - team_0_centroid) < np.linalg.norm(gk - team_1_centroid) else 1
458
- for gk in goalkeepers_xy
459
- ])
 
 
460
 
461
 
462
- def create_game_style_radar(pitch_ball_xy, pitch_players_xy, players_class_id,
463
- pitch_referees_xy, ball_path=None):
464
- """Create game-style radar view with ball trail effect"""
 
 
 
 
 
465
  annotated_frame = draw_pitch(CONFIG)
466
-
467
  # Draw ball trail with fading effect
468
  if ball_path is not None and len(ball_path) > 0:
469
  valid_path = [coords for coords in ball_path if len(coords) > 0]
@@ -474,47 +542,52 @@ def create_game_style_radar(pitch_ball_xy, pitch_players_xy, players_class_id,
474
  alpha = (i + 1) / min(20, len(valid_path))
475
  color = sv.Color(int(255 * alpha), int(255 * alpha), int(255 * alpha))
476
  annotated_frame = draw_points_on_pitch(
477
- CONFIG, coords,
478
- face_color=color,
479
- edge_color=sv.Color.BLACK,
 
480
  radius=int(6 + alpha * 4),
481
- pitch=annotated_frame
482
  )
483
-
484
  # Draw current ball position
485
  if len(pitch_ball_xy) > 0:
486
  annotated_frame = draw_points_on_pitch(
487
- CONFIG, pitch_ball_xy,
488
- face_color=sv.Color.WHITE,
489
- edge_color=sv.Color.BLACK,
490
- radius=10,
491
- pitch=annotated_frame
 
492
  )
493
-
494
  # Draw players
495
  for team_id, color_hex in zip([0, 1], ["00BFFF", "FF1493"]):
496
  mask = players_class_id == team_id
497
  if np.any(mask):
498
  annotated_frame = draw_points_on_pitch(
499
- CONFIG, pitch_players_xy[mask],
500
- face_color=sv.Color.from_hex(color_hex),
501
- edge_color=sv.Color.BLACK,
502
- radius=16,
503
- pitch=annotated_frame
 
504
  )
505
-
506
  # Draw referees
507
  if len(pitch_referees_xy) > 0:
508
  annotated_frame = draw_points_on_pitch(
509
- CONFIG, pitch_referees_xy,
510
- face_color=sv.Color.from_hex("FFD700"),
511
- edge_color=sv.Color.BLACK,
512
- radius=16,
513
- pitch=annotated_frame
 
514
  )
515
-
516
  return annotated_frame
517
 
 
518
  # ==============================================
519
  # MAIN ANALYSIS PIPELINE
520
  # ==============================================
@@ -530,9 +603,17 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
530
  - Simple events + possession + per-player stats
531
  """
532
  if not video_path:
533
- return (None, None, None, None, None,
534
- "โŒ Please upload a video file.",
535
- [], [], None)
 
 
 
 
 
 
 
 
536
 
537
  try:
538
  progress(0, desc="๐Ÿ”ง Initializing...")
@@ -540,15 +621,23 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
540
  # IDs from Roboflow model
541
  BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
542
  STRIDE = 30 # Frame sampling for training
543
- MAXLEN = 5 # Transformation matrix smoothing
544
  MAX_DISTANCE_THRESHOLD = 500 # Ball path outlier threshold
545
 
546
  # Video setup
547
  cap = cv2.VideoCapture(video_path)
548
  if not cap.isOpened():
549
- return (None, None, None, None, None,
550
- f"โŒ Failed to open video: {video_path}",
551
- [], [], None)
 
 
 
 
 
 
 
 
552
 
553
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
554
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -569,10 +658,10 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
569
  performance_tracker = PlayerPerformanceTracker(CONFIG, fps=fps)
570
 
571
  # Simple possession / events stats
572
- distance_covered_m = defaultdict(float) # tid -> meters (for overlay if needed)
573
- possession_time_player = defaultdict(float) # tid -> seconds
574
- possession_time_team = defaultdict(float) # team_id -> seconds
575
- team_of_player = {} # tid -> team_id
576
  events: List[Dict] = []
577
 
578
  prev_owner_tid: Optional[int] = None
@@ -580,19 +669,19 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
580
 
581
  # Annotators
582
  ellipse_annotator = sv.EllipseAnnotator(
583
- color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
584
- thickness=2
585
  )
586
  label_annotator = sv.LabelAnnotator(
587
- color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
588
- text_color=sv.Color.from_hex('#FFFFFF'),
589
  text_thickness=2,
590
- text_position=sv.Position.BOTTOM_CENTER
591
  )
592
  triangle_annotator = sv.TriangleAnnotator(
593
- color=sv.Color.from_hex('#FFD700'),
594
- base=20,
595
- height=17
596
  )
597
 
598
  # ByteTrack tracker with optimized settings
@@ -600,7 +689,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
600
  track_activation_threshold=0.4,
601
  lost_track_buffer=60,
602
  minimum_matching_threshold=0.85,
603
- frame_rate=fps
604
  )
605
  tracker.reset()
606
 
@@ -632,7 +721,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
632
  progress(0.05, desc="๐Ÿƒ Collecting player samples (Step 1/6)...")
633
  player_crops = []
634
  frame_count = 0
635
-
636
  while frame_count < min(total_frames, 300):
637
  ret, frame = cap.read()
638
  if not ret:
@@ -652,9 +741,17 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
652
  if len(player_crops) == 0:
653
  cap.release()
654
  out.release()
655
- return (None, None, None, None, None,
656
- "โŒ No player crops collected.",
657
- [], [], None)
 
 
 
 
 
 
 
 
658
 
659
  print(f"โœ… Collected {len(player_crops)} player samples")
660
 
@@ -673,7 +770,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
673
  frame_count = 0
674
 
675
  progress(0.2, desc="๐ŸŽฌ Processing video frames (Step 3/6)...")
676
-
677
  frame_idx = 0
678
  while True:
679
  ret, frame = cap.read()
@@ -684,10 +781,12 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
684
  t = frame_idx * dt
685
  frame_count += 1
686
  tracking_manager.reset_frame()
687
-
688
  if frame_count % 30 == 0:
689
- progress(0.2 + 0.4 * (frame_count / total_frames),
690
- desc=f"๐ŸŽฌ Processing frame {frame_count}/{total_frames}")
 
 
691
 
692
  # Player and ball detection
693
  _, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3)
@@ -700,10 +799,10 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
700
  # Separate ball from other detections
701
  ball_detections = detections[detections.class_id == BALL_ID]
702
  ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)
703
-
704
  all_detections = detections[detections.class_id != BALL_ID]
705
  all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
706
-
707
  # Track detections
708
  all_detections = tracker.update_with_detections(detections=all_detections)
709
 
@@ -716,29 +815,31 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
716
  if len(players_detections.xyxy) > 0:
717
  crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
718
  predicted_teams = team_classifier.predict(crops)
719
-
720
  # Apply stable team assignment
721
  for idx, tracker_id in enumerate(players_detections.tracker_id):
722
  tracking_manager.update_team_assignment(int(tracker_id), int(predicted_teams[idx]))
723
  predicted_teams[idx] = tracking_manager.get_stable_team_id(
724
- int(tracker_id), int(predicted_teams[idx])
 
725
  )
726
-
727
  players_detections.class_id = predicted_teams
728
 
729
  # Assign goalkeeper teams
730
  goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
731
- players_detections, goalkeepers_detections
 
732
  )
733
 
734
  # Adjust referee class_id
735
  referees_detections.class_id -= 1
736
 
737
  # Merge all detections
738
- all_detections = sv.Detections.merge([
739
- players_detections, goalkeepers_detections, referees_detections
740
- ])
741
-
742
  all_detections.class_id = all_detections.class_id.astype(int)
743
 
744
  # ========================================
@@ -752,48 +853,66 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
752
  try:
753
  result_field, _ = infer_with_confidence(FIELD_DETECTION_MODEL_ID, frame, 0.3)
754
  key_points = sv.KeyPoints.from_inference(result_field)
755
-
756
  # Filter confident keypoints
757
  filter_mask = key_points.confidence[0] > 0.5
758
  frame_ref_pts = key_points.xy[0][filter_mask]
759
  pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask]
760
-
761
  if len(frame_ref_pts) >= 4: # Need at least 4 points for homography
762
  transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts)
763
  M.append(transformer.m)
764
  transformer.m = np.mean(np.array(M), axis=0)
765
 
766
  # Transform ball position
767
- frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
768
- pitch_ball_xy = transformer.transform_points(frame_ball_xy) if len(frame_ball_xy) > 0 else np.empty((0, 2))
 
 
 
 
 
 
769
  if len(pitch_ball_xy) > 0:
770
  frame_ball_pos_pitch = pitch_ball_xy[0]
771
  ball_path_raw.append(pitch_ball_xy)
772
 
773
  # Transform all players (including goalkeepers)
774
  all_players = sv.Detections.merge([players_detections, goalkeepers_detections])
775
- players_xy = all_players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
776
- pitch_players_xy = transformer.transform_points(players_xy) if len(players_xy) > 0 else np.empty((0, 2))
777
-
 
 
 
 
 
 
778
  # Transform referees
779
- referees_xy = referees_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
780
- pitch_referees_xy = transformer.transform_points(referees_xy) if len(referees_xy) > 0 else np.empty((0, 2))
781
-
 
 
 
 
 
 
782
  # Store for radar view
783
  last_pitch_players_xy = pitch_players_xy
784
  last_players_class_id = all_players.class_id
785
  last_pitch_referees_xy = pitch_referees_xy
786
-
787
  # Update performance tracker + distance per player (meters)
788
  for idx, tracker_id in enumerate(all_players.tracker_id):
789
  tid_int = int(tracker_id)
790
  if idx < len(pitch_players_xy):
791
  pos_pitch = pitch_players_xy[idx]
792
  performance_tracker.update(
793
- tid_int,
794
- pos_pitch,
795
  int(all_players.class_id[idx]),
796
- frame_count
797
  )
798
  team_of_player[tid_int] = int(all_players.class_id[idx])
799
 
@@ -877,7 +996,10 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
877
  "from_tid": int(prev_owner_tid),
878
  "to_tid": int(owner_tid),
879
  "team_id": int(cur_team),
880
- "extra": {"player_distance_m": d_pp, "ball_travel_m": travel_m},
 
 
 
881
  },
882
  f"{label}: #{owner_tid} wins ball from #{prev_owner_tid}",
883
  )
@@ -889,12 +1011,14 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
889
  {
890
  "type": "possession_change",
891
  "t": float(t),
892
- "from_tid": int(prev_owner_tid) if prev_owner_tid is not None else None,
 
 
893
  "to_tid": int(owner_tid),
894
  "team_id": int(team_id) if team_id is not None else None,
895
  "extra": {},
896
  },
897
- "" # no extra banner for generic changes
898
  )
899
 
900
  # shot / clearance based on ball speed & direction
@@ -903,7 +1027,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
903
  and frame_ball_pos_pitch is not None
904
  and owner_tid is not None
905
  ):
906
- v_vec = (frame_ball_pos_pitch - prev_ball_pos_pitch) # pitch units
907
  # convert to meters per second
908
  dist_m = pitch_distance_m(prev_ball_pos_pitch, frame_ball_pos_pitch)
909
  speed_mps = dist_m / dt
@@ -960,7 +1084,11 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
960
  labels.append(f"#{int(tid)} T{int(cid)}")
961
 
962
  annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
963
- annotated_frame = label_annotator.annotate(annotated_frame, all_detections, labels=labels)
 
 
 
 
964
  annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
965
 
966
  # HUD: possession per team
@@ -968,7 +1096,10 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
968
  team0_pct = 100.0 * possession_time_team.get(0, 0.0) / total_poss
969
  team1_pct = 100.0 * possession_time_team.get(1, 0.0) / total_poss
970
 
971
- hud_text = f"Team 0 Ball Control: {team0_pct:5.2f}% Team 1 Ball Control: {team1_pct:5.2f}%"
 
 
 
972
  cv2.rectangle(
973
  annotated_frame,
974
  (20, annotated_frame.shape[0] - 60),
@@ -994,7 +1125,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
994
  (20, 20),
995
  (annotated_frame.shape[1] - 20, 90),
996
  (255, 255, 255),
997
- -1
998
  )
999
  cv2.putText(
1000
  annotated_frame,
@@ -1018,7 +1149,7 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1018
  # STEP 5: Clean Ball Path (Remove Outliers)
1019
  # ========================================
1020
  progress(0.65, desc="๐Ÿงน Cleaning ball trajectory (Step 4/6)...")
1021
-
1022
  # Convert to proper format for cleaning
1023
  path_for_cleaning = []
1024
  for coords in ball_path_raw:
@@ -1029,58 +1160,66 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1029
  path_for_cleaning.append(np.empty((0, 2), dtype=np.float32))
1030
  else:
1031
  path_for_cleaning.append(coords)
1032
-
1033
  # Remove outliers
1034
  cleaned_path = replace_outliers_based_on_distance(
1035
- [np.array(p).reshape(-1, 2) if len(p) > 0 else np.empty((0, 2)) for p in path_for_cleaning],
1036
- MAX_DISTANCE_THRESHOLD
 
 
 
 
 
 
 
 
1037
  )
1038
-
1039
- print(f"โœ… Ball path cleaned: {len([p for p in cleaned_path if len(p) > 0])} valid points")
1040
 
1041
  # ========================================
1042
  # STEP 6: Generate Performance Analytics
1043
  # ========================================
1044
  progress(0.75, desc="๐Ÿ“Š Generating performance analytics (Step 5/6)...")
1045
-
1046
  # Team comparison charts
1047
  comparison_fig = create_team_comparison_plot(performance_tracker)
1048
-
1049
  # Combined team heatmaps
1050
  team_heatmaps_path = "/tmp/team_heatmaps.png"
1051
  team_heatmaps = create_combined_heatmaps(performance_tracker)
1052
  cv2.imwrite(team_heatmaps_path, team_heatmaps)
1053
-
1054
  # Individual player heatmaps (top 6 by distance)
1055
  progress(0.85, desc="๐Ÿ—บ๏ธ Creating individual heatmaps...")
1056
  teams = performance_tracker.get_all_players_by_team()
1057
  top_players = []
1058
-
1059
  for team_id in [0, 1]:
1060
  if team_id in teams:
1061
  team_players = teams[team_id]
1062
- player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance_meters'])
1063
- for pid in team_players]
 
 
1064
  player_distances.sort(key=lambda x: x[1], reverse=True)
1065
  top_players.extend([pid for pid, _ in player_distances[:3]])
1066
-
1067
  individual_heatmaps = []
1068
  for pid in top_players[:6]:
1069
  heatmap = create_player_heatmap_visualization(performance_tracker, pid)
1070
  individual_heatmaps.append(heatmap)
1071
-
1072
  # Arrange individual heatmaps in grid (3 columns)
1073
  if len(individual_heatmaps) > 0:
1074
  rows = []
1075
  for i in range(0, len(individual_heatmaps), 3):
1076
- row_maps = individual_heatmaps[i:i+3]
1077
  if len(row_maps) == 3:
1078
  rows.append(np.hstack(row_maps))
1079
  elif len(row_maps) == 2:
1080
  rows.append(np.hstack([row_maps[0], row_maps[1]]))
1081
  else:
1082
  rows.append(row_maps[0])
1083
-
1084
  individual_grid = np.vstack(rows) if len(rows) > 1 else rows[0]
1085
  individual_heatmaps_path = "/tmp/individual_heatmaps.png"
1086
  cv2.imwrite(individual_heatmaps_path, individual_grid)
@@ -1095,11 +1234,13 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1095
  try:
1096
  if last_pitch_players_xy is not None:
1097
  radar_frame = create_game_style_radar(
1098
- pitch_ball_xy=cleaned_path[-1] if cleaned_path else np.empty((0, 2)),
 
 
1099
  pitch_players_xy=last_pitch_players_xy,
1100
  players_class_id=last_players_class_id,
1101
  pitch_referees_xy=last_pitch_referees_xy,
1102
- ball_path=cleaned_path
1103
  )
1104
  cv2.imwrite(radar_path, radar_frame)
1105
  else:
@@ -1122,14 +1263,14 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1122
 
1123
  row = [
1124
  int(pid),
1125
- int(stats['team_id']),
1126
- float(stats['total_distance_meters']),
1127
- float(stats['avg_velocity']),
1128
- float(stats['max_velocity']),
1129
- int(stats['frames_visible']),
1130
- int(stats['time_in_defensive_third']),
1131
- int(stats['time_in_middle_third']),
1132
- int(stats['time_in_attacking_third']),
1133
  poss_s,
1134
  poss_pct,
1135
  ]
@@ -1151,11 +1292,19 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1151
  if ev_type == "pass":
1152
  desc = f"Pass #{from_tid} โ†’ #{to_tid} (Team {team_id})"
1153
  elif ev_type == "tackle":
1154
- desc = f"Tackle: #{to_tid} wins ball from #{from_tid} (Team {team_id})"
 
 
 
1155
  elif ev_type == "interception":
1156
- desc = f"Interception: #{to_tid} intercepts #{from_tid} (Team {team_id})"
 
 
 
1157
  elif ev_type == "shot":
1158
- desc = f"Shot by #{from_tid} (Team {team_id}) at {speed_kmh:.1f} km/h"
 
 
1159
  elif ev_type == "clearance":
1160
  desc = f"Clearance by #{from_tid} (Team {team_id})"
1161
  else:
@@ -1184,32 +1333,41 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1184
  progress(0.95, desc="๐Ÿ“ Generating summary report...")
1185
 
1186
  summary_lines = ["โœ… **Analysis Complete!**\n"]
1187
- summary_lines.append(f"**Video Statistics:**")
1188
  summary_lines.append(f"- Total Frames Processed: {frame_count}")
1189
  summary_lines.append(f"- Video Resolution: {width}x{height}")
1190
  summary_lines.append(f"- Frame Rate: {fps:.2f} fps")
1191
- summary_lines.append(f"- Ball Trajectory Points: {len([p for p in cleaned_path if len(p) > 0])}\n")
1192
-
 
 
 
1193
  for team_id in [0, 1]:
1194
  if team_id not in teams:
1195
  continue
1196
-
1197
  team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
1198
  summary_lines.append(f"\n**{team_name}:**")
1199
  summary_lines.append(f"- Players Tracked: {len(teams[team_id])}")
1200
-
1201
- total_dist = sum(performance_tracker.get_player_stats(pid)['total_distance_meters']
1202
- for pid in teams[team_id])
 
 
1203
  avg_dist = total_dist / len(teams[team_id]) if len(teams[team_id]) > 0 else 0
1204
  summary_lines.append(f"- Team Total Distance: {total_dist:.1f} m")
1205
- summary_lines.append(f"- Average Distance per Player: {avg_dist:.1f} m")
1206
-
 
 
1207
  # Top 3 performers (by distance)
1208
- player_distances = [(pid, performance_tracker.get_player_stats(pid)['total_distance_meters'])
1209
- for pid in teams[team_id]]
 
 
1210
  player_distances.sort(key=lambda x: x[1], reverse=True)
1211
-
1212
- summary_lines.append(f"\n **Top 3 Performers:**")
1213
  for i, (pid, dist) in enumerate(player_distances[:3], 1):
1214
  stats = performance_tracker.get_player_stats(pid)
1215
  summary_lines.append(
@@ -1223,10 +1381,8 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1223
  for team_id in sorted(possession_time_team.keys()):
1224
  t_sec = possession_time_team[team_id]
1225
  pct = 100.0 * t_sec / total_poss if total_poss > 0 else 0.0
1226
- summary_lines.append(
1227
- f"- Team {team_id}: {t_sec:.1f} s ({pct:.1f}%)"
1228
- )
1229
-
1230
  summary_lines.append("\n**Pipeline Steps Completed:**")
1231
  summary_lines.append("โœ… 1. Player crop collection")
1232
  summary_lines.append("โœ… 2. Team classifier training")
@@ -1234,41 +1390,96 @@ def analyze_football_video(video_path: str, progress=gr.Progress()) -> Tuple:
1234
  summary_lines.append("โœ… 4. Ball trajectory cleaning")
1235
  summary_lines.append("โœ… 5. Performance analytics generation")
1236
  summary_lines.append("โœ… 6. Visualization creation")
1237
-
1238
  summary_msg = "\n".join(summary_lines)
1239
 
1240
  progress(1.0, desc="โœ… Analysis Complete!")
1241
 
1242
  # IMPORTANT: must return 9 outputs in the same order as Gradio wiring
1243
  return (
1244
- output_path, # video_output
1245
- comparison_fig, # comparison_output
1246
- team_heatmaps_path, # team_heatmaps_output
1247
  individual_heatmaps_path, # individual_heatmaps_output
1248
- radar_path, # radar_output
1249
- summary_msg, # status_output
1250
- player_stats_table, # player_stats_output (Dataframe)
1251
- events_table, # events_output (Dataframe)
1252
- events_json_path, # events_json_output (File download)
1253
  )
1254
 
1255
  except Exception as e:
1256
  error_msg = f"โŒ Error: {str(e)}"
1257
  print(error_msg)
1258
  import traceback
 
1259
  traceback.print_exc()
1260
  # Match the 9 outputs (fill with Nones/empties)
1261
  return (
1262
- None, None, None, None, None,
 
 
 
 
1263
  error_msg,
1264
- [], [], None
 
 
1265
  )
1266
 
 
1267
  # ==============================================
1268
  # GRADIO INTERFACE
1269
  # ==============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1270
  with gr.Blocks(title="โšฝ Football Performance Analyzer", theme=gr.themes.Soft()) as iface:
1271
- gr.Markdown("""
 
1272
  # โšฝ Advanced Football Video Analyzer
1273
  ### Complete Pipeline Implementation
1274
 
@@ -1281,34 +1492,37 @@ with gr.Blocks(title="โšฝ Football Performance Analyzer", theme=gr.themes.Soft()
1281
  6. **Performance Analytics** - Heatmaps, stats, possession, and event detection
1282
 
1283
  Upload a football match video to get comprehensive performance analytics!
1284
- """)
1285
-
1286
- with gr.Row():
1287
- # No "type" argument here โ€“ Gradio's Video in your Space does not support it
 
1288
  video_input = gr.Video(label="๐Ÿ“ค Upload Football Video")
1289
-
1290
  analyze_btn = gr.Button("๐Ÿš€ Start Analysis Pipeline", variant="primary", size="lg")
1291
-
1292
  with gr.Row():
1293
  status_output = gr.Textbox(label="๐Ÿ“Š Analysis Summary & Statistics", lines=25)
1294
-
1295
  with gr.Tabs():
1296
  with gr.Tab("๐Ÿ“น Annotated Video"):
1297
- gr.Markdown("### Full video with player tracking, team colors, ball detection, and events overlay")
 
 
1298
  video_output = gr.Video(label="Processed Video")
1299
-
1300
  with gr.Tab("๐Ÿ“Š Performance Comparison"):
1301
  gr.Markdown("### Interactive charts comparing player performance metrics")
1302
  comparison_output = gr.Plot(label="Team Performance Metrics")
1303
-
1304
  with gr.Tab("๐Ÿ—บ๏ธ Team Heatmaps"):
1305
  gr.Markdown("### Combined activity heatmaps showing team positioning")
1306
  team_heatmaps_output = gr.Image(label="Team Activity Heatmaps")
1307
-
1308
  with gr.Tab("๐Ÿ‘ค Individual Heatmaps"):
1309
  gr.Markdown("### Top 6 players with detailed activity analysis")
1310
  individual_heatmaps_output = gr.Image(label="Top Players Heatmaps")
1311
-
1312
  with gr.Tab("๐ŸŽฎ Game Radar View"):
1313
  gr.Markdown("### Game-style tactical view with ball trail")
1314
  radar_output = gr.Image(label="Tactical Radar View")
@@ -1317,41 +1531,44 @@ with gr.Blocks(title="โšฝ Football Performance Analyzer", theme=gr.themes.Soft()
1317
  gr.Markdown("### Per-player totals: distance, speeds, zones, possession")
1318
  player_stats_output = gr.Dataframe(
1319
  headers=PLAYER_STATS_HEADERS,
1320
- col_count=len(PLAYER_STATS_HEADERS), # IMPORTANT: must equal headers length
1321
  row_count=0,
1322
- interactive=False
1323
  )
1324
 
1325
  with gr.Tab("โฑ๏ธ Event Timeline"):
1326
- gr.Markdown("### Detected passes, tackles, interceptions, shots, clearances")
 
 
1327
  events_output = gr.Dataframe(
1328
  headers=EVENT_HEADERS,
1329
- col_count=len(EVENT_HEADERS), # IMPORTANT: must equal headers length
1330
  row_count=0,
1331
- interactive=False
1332
  )
1333
  events_json_output = gr.File(
1334
  label="Download events JSON",
1335
- file_types=[".json"]
1336
  )
1337
-
1338
  analyze_btn.click(
1339
- fn=analyze_football_video,
1340
  inputs=[video_input],
1341
  outputs=[
1342
- video_output, # 1
1343
- comparison_output, # 2
1344
  team_heatmaps_output, # 3
1345
  individual_heatmaps_output, # 4
1346
- radar_output, # 5
1347
- status_output, # 6
1348
- player_stats_output, # 7
1349
- events_output, # 8
1350
- events_json_output, # 9
1351
- ]
1352
  )
1353
-
1354
- gr.Markdown("""
 
1355
  ---
1356
  ### ๐Ÿ”ง Technical Details:
1357
 
@@ -1380,7 +1597,9 @@ with gr.Blocks(title="โšฝ Football Performance Analyzer", theme=gr.themes.Soft()
1380
  - Passes, tackles, interceptions, shots, clearances
1381
  - Event banner overlay in video
1382
  - Full event list downloadable as JSON
1383
- """)
 
 
1384
 
1385
  if __name__ == "__main__":
1386
- iface.launch(share=True)
 
45
  # ==============================================
46
  CLIENT = InferenceHTTPClient(
47
  api_url="https://detect.roboflow.com",
48
+ api_key=ROBOFLOW_API_KEY,
49
  )
50
 
51
  PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
52
  FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
53
 
54
+
55
  def infer_with_confidence(model_id: str, frame: np.ndarray, confidence_threshold: float = 0.3):
56
+ """Run inference and filter by confidence threshold."""
57
  result = CLIENT.infer(frame, model_id=model_id)
58
  detections = sv.Detections.from_inference(result)
59
  # Filter by confidence
 
61
  detections = detections[detections.confidence > confidence_threshold]
62
  return result, detections
63
 
64
+
65
  # ==============================================
66
  # SIGLIP MODEL (Embeddings)
67
  # ==============================================
 
76
 
77
  # ==============================================
78
  # TABLE HEADERS FOR GRADIO DATAFRAMES
 
79
  # ==============================================
80
  PLAYER_STATS_HEADERS = [
81
  "Player ID",
 
108
  # ==============================================
109
  def replace_outliers_based_on_distance(
110
  positions: List[np.ndarray],
111
+ distance_threshold: float,
112
  ) -> List[np.ndarray]:
113
+ """Remove outlier positions based on distance threshold."""
114
  last_valid_position: Union[np.ndarray, None] = None
115
  cleaned_positions: List[np.ndarray] = []
116
 
 
131
 
132
  return cleaned_positions
133
 
134
+
135
  # ==============================================
136
  # PITCH DISTANCE (UNITS FIX: meters)
137
  # ==============================================
 
151
  else:
152
  return d
153
 
154
+
155
  # ==============================================
156
  # PLAYER PERFORMANCE TRACKING
157
  # ==============================================
158
  class PlayerPerformanceTracker:
159
+ """Track individual player performance metrics and generate heatmaps."""
160
+
161
  def __init__(self, pitch_config, fps: float = 30.0):
162
  self.config = pitch_config
163
  self.fps = fps
164
  self.player_positions = defaultdict(list)
165
+ self.player_velocities = defaultdict(list) # km/h
166
+ self.player_distances = defaultdict(float) # meters
167
  self.player_team = {}
168
+ self.player_stats = defaultdict(
169
+ lambda: {
170
+ "frames_visible": 0,
171
+ "avg_velocity": 0.0, # km/h
172
+ "max_velocity": 0.0, # km/h
173
+ "time_in_attacking_third": 0,
174
+ "time_in_defensive_third": 0,
175
+ "time_in_middle_third": 0,
176
+ }
177
+ )
178
+
179
  def update(self, tracker_id: int, position: np.ndarray, team_id: int, frame: int):
180
+ """Update player position and calculate metrics."""
181
  if len(position) != 2:
182
  return
183
+
184
  self.player_team[tracker_id] = team_id
185
  self.player_positions[tracker_id].append((position[0], position[1], frame))
186
+ self.player_stats[tracker_id]["frames_visible"] += 1
187
+
188
  if len(self.player_positions[tracker_id]) > 1:
189
  prev_pos = np.array(self.player_positions[tracker_id][-2][:2], dtype=float)
190
  curr_pos = np.array(position, dtype=float)
 
192
  # distance in meters between frames
193
  distance_m = pitch_distance_m(prev_pos, curr_pos)
194
  self.player_distances[tracker_id] += distance_m
195
+
196
  # speed in km/h
197
  speed_mps = distance_m * self.fps
198
  speed_kmh = speed_mps * 3.6
199
  self.player_velocities[tracker_id].append(speed_kmh)
200
+
201
+ if speed_kmh > self.player_stats[tracker_id]["max_velocity"]:
202
+ self.player_stats[tracker_id]["max_velocity"] = speed_kmh
203
+
204
  pitch_length = self.config.length
205
  if position[0] < pitch_length / 3:
206
+ self.player_stats[tracker_id]["time_in_defensive_third"] += 1
207
  elif position[0] < 2 * pitch_length / 3:
208
+ self.player_stats[tracker_id]["time_in_middle_third"] += 1
209
  else:
210
+ self.player_stats[tracker_id]["time_in_attacking_third"] += 1
211
+
212
  def get_player_stats(self, tracker_id: int) -> dict:
213
+ """Get comprehensive stats for a player."""
214
  stats = self.player_stats[tracker_id].copy()
215
+
216
  if len(self.player_velocities[tracker_id]) > 0:
217
+ stats["avg_velocity"] = float(np.mean(self.player_velocities[tracker_id]))
218
+
219
+ stats["total_distance_meters"] = float(self.player_distances[tracker_id])
220
+ stats["team_id"] = int(self.player_team.get(tracker_id, -1))
221
+
222
  return stats
223
+
224
  def generate_heatmap(self, tracker_id: int, resolution: int = 100) -> np.ndarray:
225
+ """Generate heatmap for a specific player."""
226
  if tracker_id not in self.player_positions or len(self.player_positions[tracker_id]) == 0:
227
  return np.zeros((resolution, resolution))
228
+
229
  positions = np.array([(x, y) for x, y, _ in self.player_positions[tracker_id]])
230
+
231
  pitch_length = self.config.length
232
  pitch_width = self.config.width
233
+
234
  heatmap, xedges, yedges = np.histogram2d(
235
+ positions[:, 0],
236
+ positions[:, 1],
237
  bins=[resolution, resolution],
238
+ range=[[0, pitch_length], [0, pitch_width]],
239
  )
240
+
241
  heatmap = gaussian_filter(heatmap, sigma=3)
242
+
243
  return heatmap.T
244
+
245
  def get_all_players_by_team(self) -> Dict[int, List[int]]:
246
+ """Get all player IDs grouped by team."""
247
  teams = defaultdict(list)
248
  for tracker_id, team_id in self.player_team.items():
249
  teams[team_id].append(tracker_id)
250
  return teams
251
 
252
+
253
  # ==============================================
254
  # TRACKING MANAGER
255
  # ==============================================
256
  class PlayerTrackingManager:
257
+ """Manages persistent player tracking with team assignment stability."""
258
+
259
  def __init__(self, max_history=10):
260
  self.tracker_team_history: Dict[int, List[int]] = defaultdict(list)
261
  self.max_history = max_history
262
  self.active_trackers = set()
263
+
264
  def update_team_assignment(self, tracker_id: int, team_id: int):
265
+ """Store team assignment history for each tracker."""
266
  self.tracker_team_history[tracker_id].append(team_id)
267
  if len(self.tracker_team_history[tracker_id]) > self.max_history:
268
  self.tracker_team_history[tracker_id].pop(0)
269
  self.active_trackers.add(tracker_id)
270
+
271
  def get_stable_team_id(self, tracker_id: int, current_team_id: int) -> int:
272
+ """Get stable team ID using majority voting from history."""
273
  if tracker_id not in self.tracker_team_history or len(self.tracker_team_history[tracker_id]) < 3:
274
  return current_team_id
275
+
276
  history = self.tracker_team_history[tracker_id]
277
  team_counts = np.bincount(history)
278
  stable_team = int(np.argmax(team_counts))
279
  return stable_team
280
+
281
  def get_player_count_by_team(self) -> Dict[int, int]:
282
+ """Get current count of players per team."""
283
  team_counts = defaultdict(int)
284
  for tracker_id in self.active_trackers:
285
  if tracker_id in self.tracker_team_history and len(self.tracker_team_history[tracker_id]) > 0:
286
+ stable_team = self.get_stable_team_id(
287
+ tracker_id,
288
+ self.tracker_team_history[tracker_id][-1],
289
+ )
290
  team_counts[stable_team] += 1
291
  return team_counts
292
+
293
  def reset_frame(self):
294
+ """Reset active trackers for new frame."""
295
  self.active_trackers = set()
296
 
297
+
298
  # ==============================================
299
  # VISUALIZATION FUNCTIONS
300
  # ==============================================
301
+ def create_player_heatmap_visualization(
302
+ performance_tracker: PlayerPerformanceTracker,
303
+ tracker_id: int,
304
+ ) -> np.ndarray:
305
+ """Create a single player heatmap overlay on pitch."""
306
  pitch = draw_pitch(CONFIG)
307
  heatmap = performance_tracker.generate_heatmap(tracker_id, resolution=150)
308
+
309
  if heatmap.max() > 0:
310
  heatmap = heatmap / heatmap.max()
311
+
312
  padding = 50
313
+
314
  pitch_height, pitch_width = pitch.shape[:2]
315
+ heatmap_resized = cv2.resize(heatmap, (pitch_width - 2 * padding, pitch_height - 2 * padding))
316
+
317
  heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
318
+
319
  overlay = pitch.copy()
320
+ overlay[padding : pitch_height - padding, padding : pitch_width - padding] = heatmap_colored
321
+
322
  result = cv2.addWeighted(pitch, 0.6, overlay, 0.4, 0)
323
+
324
  stats = performance_tracker.get_player_stats(tracker_id)
325
+ team_color = "Blue" if stats["team_id"] == 0 else "Pink"
326
+
327
  text_lines = [
328
  f"Player #{tracker_id} ({team_color} Team)",
329
  f"Distance: {stats['total_distance_meters']:.1f} m",
330
  f"Avg Speed: {stats['avg_velocity']:.2f} km/h",
331
  f"Max Speed: {stats['max_velocity']:.2f} km/h",
332
+ f"Frames: {stats['frames_visible']}",
333
  ]
334
+
335
  y_offset = 30
336
  for line in text_lines:
337
+ cv2.putText(
338
+ result,
339
+ line,
340
+ (10, y_offset),
341
+ cv2.FONT_HERSHEY_SIMPLEX,
342
+ 0.6,
343
+ (255, 255, 255),
344
+ 2,
345
+ cv2.LINE_AA,
346
+ )
347
  y_offset += 25
348
+
349
  return result
350
 
351
 
352
  def create_team_comparison_plot(performance_tracker: PlayerPerformanceTracker) -> go.Figure:
353
+ """Create interactive performance comparison plots."""
354
  teams = performance_tracker.get_all_players_by_team()
355
+
356
  fig = make_subplots(
357
+ rows=2,
358
+ cols=2,
359
+ subplot_titles=(
360
+ "Distance Covered",
361
+ "Average Speed",
362
+ "Max Speed",
363
+ "Activity by Zone",
364
+ ),
365
+ specs=[[{"type": "bar"}, {"type": "bar"}], [{"type": "bar"}, {"type": "bar"}]],
366
  )
367
+
368
+ colors = {0: "#00BFFF", 1: "#FF1493"}
369
+ team_names = {0: "Team 0 (Blue)", 1: "Team 1 (Pink)"}
370
+
371
  for team_id, player_ids in teams.items():
372
  if team_id not in [0, 1]:
373
  continue
374
+
375
  distances = []
376
  avg_speeds = []
377
  max_speeds = []
378
  attacking_time = []
379
+
380
  for pid in player_ids:
381
  stats = performance_tracker.get_player_stats(pid)
382
+ distances.append(stats["total_distance_meters"])
383
+ avg_speeds.append(stats["avg_velocity"]) # km/h
384
+ max_speeds.append(stats["max_velocity"]) # km/h
385
+ attacking_time.append(stats["time_in_attacking_third"])
386
+
387
  player_labels = [f"#{pid}" for pid in player_ids]
388
+
389
  fig.add_trace(
390
+ go.Bar(
391
+ x=player_labels,
392
+ y=distances,
393
+ name=team_names[team_id],
394
+ marker_color=colors[team_id],
395
+ showlegend=True,
396
+ ),
397
+ row=1,
398
+ col=1,
399
  )
400
+
401
  fig.add_trace(
402
+ go.Bar(
403
+ x=player_labels,
404
+ y=avg_speeds,
405
+ name=team_names[team_id],
406
+ marker_color=colors[team_id],
407
+ showlegend=False,
408
+ ),
409
+ row=1,
410
+ col=2,
411
  )
412
+
413
  fig.add_trace(
414
+ go.Bar(
415
+ x=player_labels,
416
+ y=max_speeds,
417
+ name=team_names[team_id],
418
+ marker_color=colors[team_id],
419
+ showlegend=False,
420
+ ),
421
+ row=2,
422
+ col=1,
423
  )
424
+
425
  fig.add_trace(
426
+ go.Bar(
427
+ x=player_labels,
428
+ y=attacking_time,
429
+ name=team_names[team_id],
430
+ marker_color=colors[team_id],
431
+ showlegend=False,
432
+ ),
433
+ row=2,
434
+ col=2,
435
  )
436
+
437
  fig.update_xaxes(title_text="Players", row=1, col=1)
438
  fig.update_xaxes(title_text="Players", row=1, col=2)
439
  fig.update_xaxes(title_text="Players", row=2, col=1)
440
  fig.update_xaxes(title_text="Players", row=2, col=2)
441
+
442
  fig.update_yaxes(title_text="Distance (m)", row=1, col=1)
443
  fig.update_yaxes(title_text="Speed (km/h)", row=1, col=2)
444
  fig.update_yaxes(title_text="Speed (km/h)", row=2, col=1)
445
  fig.update_yaxes(title_text="Frames in Zone", row=2, col=2)
446
+
447
+ fig.update_layout(height=800, title_text="Team Performance Comparison", barmode="group")
448
+
449
  return fig
450
 
451
 
452
  def create_combined_heatmaps(performance_tracker: PlayerPerformanceTracker) -> np.ndarray:
453
+ """Create side-by-side team heatmaps."""
454
  teams = performance_tracker.get_all_players_by_team()
455
+
456
  team_heatmaps = []
457
  for team_id in [0, 1]:
458
  if team_id not in teams:
459
  continue
460
+
461
  combined_heatmap = np.zeros((150, 150))
462
  for pid in teams[team_id]:
463
  player_heatmap = performance_tracker.generate_heatmap(pid, resolution=150)
464
  combined_heatmap += player_heatmap
465
+
466
  if combined_heatmap.max() > 0:
467
  combined_heatmap = combined_heatmap / combined_heatmap.max()
468
+
469
  pitch = draw_pitch(CONFIG)
470
  padding = 50
471
  pitch_height, pitch_width = pitch.shape[:2]
472
+ heatmap_resized = cv2.resize(
473
+ combined_heatmap,
474
+ (pitch_width - 2 * padding, pitch_height - 2 * padding),
475
+ )
476
+
477
  colormap = cv2.COLORMAP_JET if team_id == 0 else cv2.COLORMAP_HOT
478
  heatmap_colored = cv2.applyColorMap((heatmap_resized * 255).astype(np.uint8), colormap)
479
+
480
  overlay = pitch.copy()
481
+ overlay[padding : pitch_height - padding, padding : pitch_width - padding] = heatmap_colored
482
  result = cv2.addWeighted(pitch, 0.5, overlay, 0.5, 0)
483
+
484
  team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
485
+ cv2.putText(
486
+ result,
487
+ team_name,
488
+ (10, 30),
489
+ cv2.FONT_HERSHEY_SIMPLEX,
490
+ 1,
491
+ (255, 255, 255),
492
+ 2,
493
+ cv2.LINE_AA,
494
+ )
495
+
496
  team_heatmaps.append(result)
497
+
498
  if len(team_heatmaps) == 2:
499
  return np.hstack(team_heatmaps)
500
  elif len(team_heatmaps) == 1:
 
502
  else:
503
  return draw_pitch(CONFIG)
504
 
505
+
506
  # ==============================================
507
  # HELPER FUNCTIONS
508
  # ==============================================
509
  def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
510
+ """Assign goalkeepers to the nearest team centroid."""
511
  if len(goalkeepers) == 0 or len(players) == 0:
512
  return np.array([])
513
  goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
514
  players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
515
  team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
516
  team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)
517
+ return np.array(
518
+ [
519
+ 0 if np.linalg.norm(gk - team_0_centroid) < np.linalg.norm(gk - team_1_centroid) else 1
520
+ for gk in goalkeepers_xy
521
+ ]
522
+ )
523
 
524
 
525
+ def create_game_style_radar(
526
+ pitch_ball_xy,
527
+ pitch_players_xy,
528
+ players_class_id,
529
+ pitch_referees_xy,
530
+ ball_path=None,
531
+ ):
532
+ """Create game-style radar view with ball trail effect."""
533
  annotated_frame = draw_pitch(CONFIG)
534
+
535
  # Draw ball trail with fading effect
536
  if ball_path is not None and len(ball_path) > 0:
537
  valid_path = [coords for coords in ball_path if len(coords) > 0]
 
542
  alpha = (i + 1) / min(20, len(valid_path))
543
  color = sv.Color(int(255 * alpha), int(255 * alpha), int(255 * alpha))
544
  annotated_frame = draw_points_on_pitch(
545
+ CONFIG,
546
+ coords,
547
+ face_color=color,
548
+ edge_color=sv.Color.BLACK,
549
  radius=int(6 + alpha * 4),
550
+ pitch=annotated_frame,
551
  )
552
+
553
  # Draw current ball position
554
  if len(pitch_ball_xy) > 0:
555
  annotated_frame = draw_points_on_pitch(
556
+ CONFIG,
557
+ pitch_ball_xy,
558
+ face_color=sv.Color.WHITE,
559
+ edge_color=sv.Color.BLACK,
560
+ radius=10,
561
+ pitch=annotated_frame,
562
  )
563
+
564
  # Draw players
565
  for team_id, color_hex in zip([0, 1], ["00BFFF", "FF1493"]):
566
  mask = players_class_id == team_id
567
  if np.any(mask):
568
  annotated_frame = draw_points_on_pitch(
569
+ CONFIG,
570
+ pitch_players_xy[mask],
571
+ face_color=sv.Color.from_hex(color_hex),
572
+ edge_color=sv.Color.BLACK,
573
+ radius=16,
574
+ pitch=annotated_frame,
575
  )
576
+
577
  # Draw referees
578
  if len(pitch_referees_xy) > 0:
579
  annotated_frame = draw_points_on_pitch(
580
+ CONFIG,
581
+ pitch_referees_xy,
582
+ face_color=sv.Color.from_hex("FFD700"),
583
+ edge_color=sv.Color.BLACK,
584
+ radius=16,
585
+ pitch=annotated_frame,
586
  )
587
+
588
  return annotated_frame
589
 
590
+
591
  # ==============================================
592
  # MAIN ANALYSIS PIPELINE
593
  # ==============================================
 
603
  - Simple events + possession + per-player stats
604
  """
605
  if not video_path:
606
+ return (
607
+ None,
608
+ None,
609
+ None,
610
+ None,
611
+ None,
612
+ "โŒ Please upload a video file.",
613
+ [],
614
+ [],
615
+ None,
616
+ )
617
 
618
  try:
619
  progress(0, desc="๐Ÿ”ง Initializing...")
 
621
  # IDs from Roboflow model
622
  BALL_ID, GOALKEEPER_ID, PLAYER_ID, REFEREE_ID = 0, 1, 2, 3
623
  STRIDE = 30 # Frame sampling for training
624
+ MAXLEN = 5 # Transformation matrix smoothing
625
  MAX_DISTANCE_THRESHOLD = 500 # Ball path outlier threshold
626
 
627
  # Video setup
628
  cap = cv2.VideoCapture(video_path)
629
  if not cap.isOpened():
630
+ return (
631
+ None,
632
+ None,
633
+ None,
634
+ None,
635
+ None,
636
+ f"โŒ Failed to open video: {video_path}",
637
+ [],
638
+ [],
639
+ None,
640
+ )
641
 
642
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
643
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
 
658
  performance_tracker = PlayerPerformanceTracker(CONFIG, fps=fps)
659
 
660
  # Simple possession / events stats
661
+ distance_covered_m = defaultdict(float) # tid -> meters
662
+ possession_time_player = defaultdict(float) # tid -> seconds
663
+ possession_time_team = defaultdict(float) # team_id -> seconds
664
+ team_of_player = {} # tid -> team_id
665
  events: List[Dict] = []
666
 
667
  prev_owner_tid: Optional[int] = None
 
669
 
670
  # Annotators
671
  ellipse_annotator = sv.EllipseAnnotator(
672
+ color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
673
+ thickness=2,
674
  )
675
  label_annotator = sv.LabelAnnotator(
676
+ color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
677
+ text_color=sv.Color.from_hex("#FFFFFF"),
678
  text_thickness=2,
679
+ text_position=sv.Position.BOTTOM_CENTER,
680
  )
681
  triangle_annotator = sv.TriangleAnnotator(
682
+ color=sv.Color.from_hex("#FFD700"),
683
+ base=20,
684
+ height=17,
685
  )
686
 
687
  # ByteTrack tracker with optimized settings
 
689
  track_activation_threshold=0.4,
690
  lost_track_buffer=60,
691
  minimum_matching_threshold=0.85,
692
+ frame_rate=fps,
693
  )
694
  tracker.reset()
695
 
 
721
  progress(0.05, desc="๐Ÿƒ Collecting player samples (Step 1/6)...")
722
  player_crops = []
723
  frame_count = 0
724
+
725
  while frame_count < min(total_frames, 300):
726
  ret, frame = cap.read()
727
  if not ret:
 
741
  if len(player_crops) == 0:
742
  cap.release()
743
  out.release()
744
+ return (
745
+ None,
746
+ None,
747
+ None,
748
+ None,
749
+ None,
750
+ "โŒ No player crops collected.",
751
+ [],
752
+ [],
753
+ None,
754
+ )
755
 
756
  print(f"โœ… Collected {len(player_crops)} player samples")
757
 
 
770
  frame_count = 0
771
 
772
  progress(0.2, desc="๐ŸŽฌ Processing video frames (Step 3/6)...")
773
+
774
  frame_idx = 0
775
  while True:
776
  ret, frame = cap.read()
 
781
  t = frame_idx * dt
782
  frame_count += 1
783
  tracking_manager.reset_frame()
784
+
785
  if frame_count % 30 == 0:
786
+ progress(
787
+ 0.2 + 0.4 * (frame_count / max(total_frames, 1)),
788
+ desc=f"๐ŸŽฌ Processing frame {frame_count}/{total_frames}",
789
+ )
790
 
791
  # Player and ball detection
792
  _, detections = infer_with_confidence(PLAYER_DETECTION_MODEL_ID, frame, 0.3)
 
799
  # Separate ball from other detections
800
  ball_detections = detections[detections.class_id == BALL_ID]
801
  ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)
802
+
803
  all_detections = detections[detections.class_id != BALL_ID]
804
  all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
805
+
806
  # Track detections
807
  all_detections = tracker.update_with_detections(detections=all_detections)
808
 
 
815
  if len(players_detections.xyxy) > 0:
816
  crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
817
  predicted_teams = team_classifier.predict(crops)
818
+
819
  # Apply stable team assignment
820
  for idx, tracker_id in enumerate(players_detections.tracker_id):
821
  tracking_manager.update_team_assignment(int(tracker_id), int(predicted_teams[idx]))
822
  predicted_teams[idx] = tracking_manager.get_stable_team_id(
823
+ int(tracker_id),
824
+ int(predicted_teams[idx]),
825
  )
826
+
827
  players_detections.class_id = predicted_teams
828
 
829
  # Assign goalkeeper teams
830
  goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
831
+ players_detections,
832
+ goalkeepers_detections,
833
  )
834
 
835
  # Adjust referee class_id
836
  referees_detections.class_id -= 1
837
 
838
  # Merge all detections
839
+ all_detections = sv.Detections.merge(
840
+ [players_detections, goalkeepers_detections, referees_detections]
841
+ )
842
+
843
  all_detections.class_id = all_detections.class_id.astype(int)
844
 
845
  # ========================================
 
853
  try:
854
  result_field, _ = infer_with_confidence(FIELD_DETECTION_MODEL_ID, frame, 0.3)
855
  key_points = sv.KeyPoints.from_inference(result_field)
856
+
857
  # Filter confident keypoints
858
  filter_mask = key_points.confidence[0] > 0.5
859
  frame_ref_pts = key_points.xy[0][filter_mask]
860
  pitch_ref_pts = np.array(CONFIG.vertices)[filter_mask]
861
+
862
  if len(frame_ref_pts) >= 4: # Need at least 4 points for homography
863
  transformer = ViewTransformer(source=frame_ref_pts, target=pitch_ref_pts)
864
  M.append(transformer.m)
865
  transformer.m = np.mean(np.array(M), axis=0)
866
 
867
  # Transform ball position
868
+ frame_ball_xy = ball_detections.get_anchors_coordinates(
869
+ sv.Position.BOTTOM_CENTER
870
+ )
871
+ pitch_ball_xy = (
872
+ transformer.transform_points(frame_ball_xy)
873
+ if len(frame_ball_xy) > 0
874
+ else np.empty((0, 2))
875
+ )
876
  if len(pitch_ball_xy) > 0:
877
  frame_ball_pos_pitch = pitch_ball_xy[0]
878
  ball_path_raw.append(pitch_ball_xy)
879
 
880
  # Transform all players (including goalkeepers)
881
  all_players = sv.Detections.merge([players_detections, goalkeepers_detections])
882
+ players_xy = all_players.get_anchors_coordinates(
883
+ sv.Position.BOTTOM_CENTER
884
+ )
885
+ pitch_players_xy = (
886
+ transformer.transform_points(players_xy)
887
+ if len(players_xy) > 0
888
+ else np.empty((0, 2))
889
+ )
890
+
891
  # Transform referees
892
+ referees_xy = referees_detections.get_anchors_coordinates(
893
+ sv.Position.BOTTOM_CENTER
894
+ )
895
+ pitch_referees_xy = (
896
+ transformer.transform_points(referees_xy)
897
+ if len(referees_xy) > 0
898
+ else np.empty((0, 2))
899
+ )
900
+
901
  # Store for radar view
902
  last_pitch_players_xy = pitch_players_xy
903
  last_players_class_id = all_players.class_id
904
  last_pitch_referees_xy = pitch_referees_xy
905
+
906
  # Update performance tracker + distance per player (meters)
907
  for idx, tracker_id in enumerate(all_players.tracker_id):
908
  tid_int = int(tracker_id)
909
  if idx < len(pitch_players_xy):
910
  pos_pitch = pitch_players_xy[idx]
911
  performance_tracker.update(
912
+ tid_int,
913
+ pos_pitch,
914
  int(all_players.class_id[idx]),
915
+ frame_count,
916
  )
917
  team_of_player[tid_int] = int(all_players.class_id[idx])
918
 
 
996
  "from_tid": int(prev_owner_tid),
997
  "to_tid": int(owner_tid),
998
  "team_id": int(cur_team),
999
+ "extra": {
1000
+ "player_distance_m": d_pp,
1001
+ "ball_travel_m": travel_m,
1002
+ },
1003
  },
1004
  f"{label}: #{owner_tid} wins ball from #{prev_owner_tid}",
1005
  )
 
1011
  {
1012
  "type": "possession_change",
1013
  "t": float(t),
1014
+ "from_tid": int(prev_owner_tid)
1015
+ if prev_owner_tid is not None
1016
+ else None,
1017
  "to_tid": int(owner_tid),
1018
  "team_id": int(team_id) if team_id is not None else None,
1019
  "extra": {},
1020
  },
1021
+ "",
1022
  )
1023
 
1024
  # shot / clearance based on ball speed & direction
 
1027
  and frame_ball_pos_pitch is not None
1028
  and owner_tid is not None
1029
  ):
1030
+ v_vec = frame_ball_pos_pitch - prev_ball_pos_pitch # pitch units
1031
  # convert to meters per second
1032
  dist_m = pitch_distance_m(prev_ball_pos_pitch, frame_ball_pos_pitch)
1033
  speed_mps = dist_m / dt
 
1084
  labels.append(f"#{int(tid)} T{int(cid)}")
1085
 
1086
  annotated_frame = ellipse_annotator.annotate(annotated_frame, all_detections)
1087
+ annotated_frame = label_annotator.annotate(
1088
+ annotated_frame,
1089
+ all_detections,
1090
+ labels=labels,
1091
+ )
1092
  annotated_frame = triangle_annotator.annotate(annotated_frame, ball_detections)
1093
 
1094
  # HUD: possession per team
 
1096
  team0_pct = 100.0 * possession_time_team.get(0, 0.0) / total_poss
1097
  team1_pct = 100.0 * possession_time_team.get(1, 0.0) / total_poss
1098
 
1099
+ hud_text = (
1100
+ f"Team 0 Ball Control: {team0_pct:5.2f}% "
1101
+ f"Team 1 Ball Control: {team1_pct:5.2f}%"
1102
+ )
1103
  cv2.rectangle(
1104
  annotated_frame,
1105
  (20, annotated_frame.shape[0] - 60),
 
1125
  (20, 20),
1126
  (annotated_frame.shape[1] - 20, 90),
1127
  (255, 255, 255),
1128
+ -1,
1129
  )
1130
  cv2.putText(
1131
  annotated_frame,
 
1149
  # STEP 5: Clean Ball Path (Remove Outliers)
1150
  # ========================================
1151
  progress(0.65, desc="๐Ÿงน Cleaning ball trajectory (Step 4/6)...")
1152
+
1153
  # Convert to proper format for cleaning
1154
  path_for_cleaning = []
1155
  for coords in ball_path_raw:
 
1160
  path_for_cleaning.append(np.empty((0, 2), dtype=np.float32))
1161
  else:
1162
  path_for_cleaning.append(coords)
1163
+
1164
  # Remove outliers
1165
  cleaned_path = replace_outliers_based_on_distance(
1166
+ [
1167
+ np.array(p).reshape(-1, 2) if len(p) > 0 else np.empty((0, 2))
1168
+ for p in path_for_cleaning
1169
+ ],
1170
+ MAX_DISTANCE_THRESHOLD,
1171
+ )
1172
+
1173
+ print(
1174
+ f"โœ… Ball path cleaned: "
1175
+ f"{len([p for p in cleaned_path if len(p) > 0])} valid points"
1176
  )
 
 
1177
 
1178
  # ========================================
1179
  # STEP 6: Generate Performance Analytics
1180
  # ========================================
1181
  progress(0.75, desc="๐Ÿ“Š Generating performance analytics (Step 5/6)...")
1182
+
1183
  # Team comparison charts
1184
  comparison_fig = create_team_comparison_plot(performance_tracker)
1185
+
1186
  # Combined team heatmaps
1187
  team_heatmaps_path = "/tmp/team_heatmaps.png"
1188
  team_heatmaps = create_combined_heatmaps(performance_tracker)
1189
  cv2.imwrite(team_heatmaps_path, team_heatmaps)
1190
+
1191
  # Individual player heatmaps (top 6 by distance)
1192
  progress(0.85, desc="๐Ÿ—บ๏ธ Creating individual heatmaps...")
1193
  teams = performance_tracker.get_all_players_by_team()
1194
  top_players = []
1195
+
1196
  for team_id in [0, 1]:
1197
  if team_id in teams:
1198
  team_players = teams[team_id]
1199
+ player_distances = [
1200
+ (pid, performance_tracker.get_player_stats(pid)["total_distance_meters"])
1201
+ for pid in team_players
1202
+ ]
1203
  player_distances.sort(key=lambda x: x[1], reverse=True)
1204
  top_players.extend([pid for pid, _ in player_distances[:3]])
1205
+
1206
  individual_heatmaps = []
1207
  for pid in top_players[:6]:
1208
  heatmap = create_player_heatmap_visualization(performance_tracker, pid)
1209
  individual_heatmaps.append(heatmap)
1210
+
1211
  # Arrange individual heatmaps in grid (3 columns)
1212
  if len(individual_heatmaps) > 0:
1213
  rows = []
1214
  for i in range(0, len(individual_heatmaps), 3):
1215
+ row_maps = individual_heatmaps[i : i + 3]
1216
  if len(row_maps) == 3:
1217
  rows.append(np.hstack(row_maps))
1218
  elif len(row_maps) == 2:
1219
  rows.append(np.hstack([row_maps[0], row_maps[1]]))
1220
  else:
1221
  rows.append(row_maps[0])
1222
+
1223
  individual_grid = np.vstack(rows) if len(rows) > 1 else rows[0]
1224
  individual_heatmaps_path = "/tmp/individual_heatmaps.png"
1225
  cv2.imwrite(individual_heatmaps_path, individual_grid)
 
1234
  try:
1235
  if last_pitch_players_xy is not None:
1236
  radar_frame = create_game_style_radar(
1237
+ pitch_ball_xy=cleaned_path[-1]
1238
+ if cleaned_path
1239
+ else np.empty((0, 2)),
1240
  pitch_players_xy=last_pitch_players_xy,
1241
  players_class_id=last_players_class_id,
1242
  pitch_referees_xy=last_pitch_referees_xy,
1243
+ ball_path=cleaned_path,
1244
  )
1245
  cv2.imwrite(radar_path, radar_frame)
1246
  else:
 
1263
 
1264
  row = [
1265
  int(pid),
1266
+ int(stats["team_id"]),
1267
+ float(stats["total_distance_meters"]),
1268
+ float(stats["avg_velocity"]),
1269
+ float(stats["max_velocity"]),
1270
+ int(stats["frames_visible"]),
1271
+ int(stats["time_in_defensive_third"]),
1272
+ int(stats["time_in_middle_third"]),
1273
+ int(stats["time_in_attacking_third"]),
1274
  poss_s,
1275
  poss_pct,
1276
  ]
 
1292
  if ev_type == "pass":
1293
  desc = f"Pass #{from_tid} โ†’ #{to_tid} (Team {team_id})"
1294
  elif ev_type == "tackle":
1295
+ desc = (
1296
+ f"Tackle: #{to_tid} wins ball from #{from_tid} "
1297
+ f"(Team {team_id})"
1298
+ )
1299
  elif ev_type == "interception":
1300
+ desc = (
1301
+ f"Interception: #{to_tid} intercepts #{from_tid} "
1302
+ f"(Team {team_id})"
1303
+ )
1304
  elif ev_type == "shot":
1305
+ desc = (
1306
+ f"Shot by #{from_tid} (Team {team_id}) at {speed_kmh:.1f} km/h"
1307
+ )
1308
  elif ev_type == "clearance":
1309
  desc = f"Clearance by #{from_tid} (Team {team_id})"
1310
  else:
 
1333
  progress(0.95, desc="๐Ÿ“ Generating summary report...")
1334
 
1335
  summary_lines = ["โœ… **Analysis Complete!**\n"]
1336
+ summary_lines.append("**Video Statistics:**")
1337
  summary_lines.append(f"- Total Frames Processed: {frame_count}")
1338
  summary_lines.append(f"- Video Resolution: {width}x{height}")
1339
  summary_lines.append(f"- Frame Rate: {fps:.2f} fps")
1340
+ summary_lines.append(
1341
+ f"- Ball Trajectory Points: "
1342
+ f"{len([p for p in cleaned_path if len(p) > 0])}\n"
1343
+ )
1344
+
1345
  for team_id in [0, 1]:
1346
  if team_id not in teams:
1347
  continue
1348
+
1349
  team_name = "Team 0 (Blue)" if team_id == 0 else "Team 1 (Pink)"
1350
  summary_lines.append(f"\n**{team_name}:**")
1351
  summary_lines.append(f"- Players Tracked: {len(teams[team_id])}")
1352
+
1353
+ total_dist = sum(
1354
+ performance_tracker.get_player_stats(pid)["total_distance_meters"]
1355
+ for pid in teams[team_id]
1356
+ )
1357
  avg_dist = total_dist / len(teams[team_id]) if len(teams[team_id]) > 0 else 0
1358
  summary_lines.append(f"- Team Total Distance: {total_dist:.1f} m")
1359
+ summary_lines.append(
1360
+ f"- Average Distance per Player: {avg_dist:.1f} m"
1361
+ )
1362
+
1363
  # Top 3 performers (by distance)
1364
+ player_distances = [
1365
+ (pid, performance_tracker.get_player_stats(pid)["total_distance_meters"])
1366
+ for pid in teams[team_id]
1367
+ ]
1368
  player_distances.sort(key=lambda x: x[1], reverse=True)
1369
+
1370
+ summary_lines.append("\n **Top 3 Performers:**")
1371
  for i, (pid, dist) in enumerate(player_distances[:3], 1):
1372
  stats = performance_tracker.get_player_stats(pid)
1373
  summary_lines.append(
 
1381
  for team_id in sorted(possession_time_team.keys()):
1382
  t_sec = possession_time_team[team_id]
1383
  pct = 100.0 * t_sec / total_poss if total_poss > 0 else 0.0
1384
+ summary_lines.append(f"- Team {team_id}: {t_sec:.1f} s ({pct:.1f}%)")
1385
+
 
 
1386
  summary_lines.append("\n**Pipeline Steps Completed:**")
1387
  summary_lines.append("โœ… 1. Player crop collection")
1388
  summary_lines.append("โœ… 2. Team classifier training")
 
1390
  summary_lines.append("โœ… 4. Ball trajectory cleaning")
1391
  summary_lines.append("โœ… 5. Performance analytics generation")
1392
  summary_lines.append("โœ… 6. Visualization creation")
1393
+
1394
  summary_msg = "\n".join(summary_lines)
1395
 
1396
  progress(1.0, desc="โœ… Analysis Complete!")
1397
 
1398
  # IMPORTANT: must return 9 outputs in the same order as Gradio wiring
1399
  return (
1400
+ output_path, # video_output
1401
+ comparison_fig, # comparison_output
1402
+ team_heatmaps_path, # team_heatmaps_output
1403
  individual_heatmaps_path, # individual_heatmaps_output
1404
+ radar_path, # radar_output
1405
+ summary_msg, # status_output
1406
+ player_stats_table, # player_stats_output (Dataframe)
1407
+ events_table, # events_output (Dataframe)
1408
+ events_json_path, # events_json_output (File download)
1409
  )
1410
 
1411
  except Exception as e:
1412
  error_msg = f"โŒ Error: {str(e)}"
1413
  print(error_msg)
1414
  import traceback
1415
+
1416
  traceback.print_exc()
1417
  # Match the 9 outputs (fill with Nones/empties)
1418
  return (
1419
+ None,
1420
+ None,
1421
+ None,
1422
+ None,
1423
+ None,
1424
  error_msg,
1425
+ [],
1426
+ [],
1427
+ None,
1428
  )
1429
 
1430
+
1431
  # ==============================================
1432
  # GRADIO INTERFACE
1433
  # ==============================================
1434
+
1435
+ def run_pipeline(video) -> Tuple:
1436
+ """
1437
+ Gradio wrapper: accept the raw video object from gr.Video and
1438
+ convert it to a filesystem path for analyze_football_video().
1439
+ """
1440
+ if video is None:
1441
+ return (
1442
+ None,
1443
+ None,
1444
+ None,
1445
+ None,
1446
+ None,
1447
+ "โŒ Please upload a video file.",
1448
+ [],
1449
+ [],
1450
+ None,
1451
+ )
1452
+
1453
+ # On Spaces, Video input is usually a dict with at least a "path" key.
1454
+ if isinstance(video, dict):
1455
+ video_path = (
1456
+ video.get("path")
1457
+ or video.get("name")
1458
+ or video.get("filename")
1459
+ )
1460
+ else:
1461
+ # Fallback: if it's already a string/path-like
1462
+ video_path = str(video)
1463
+
1464
+ if not video_path:
1465
+ return (
1466
+ None,
1467
+ None,
1468
+ None,
1469
+ None,
1470
+ None,
1471
+ "โŒ Could not resolve video file path from upload.",
1472
+ [],
1473
+ [],
1474
+ None,
1475
+ )
1476
+
1477
+ return analyze_football_video(video_path)
1478
+
1479
+
1480
  with gr.Blocks(title="โšฝ Football Performance Analyzer", theme=gr.themes.Soft()) as iface:
1481
+ gr.Markdown(
1482
+ """
1483
  # โšฝ Advanced Football Video Analyzer
1484
  ### Complete Pipeline Implementation
1485
 
 
1492
  6. **Performance Analytics** - Heatmaps, stats, possession, and event detection
1493
 
1494
  Upload a football match video to get comprehensive performance analytics!
1495
+ """
1496
+ )
1497
+
1498
+ with gr.Row():
1499
+ # No "type" argument โ€“ your Gradio version does not support it
1500
  video_input = gr.Video(label="๐Ÿ“ค Upload Football Video")
1501
+
1502
  analyze_btn = gr.Button("๐Ÿš€ Start Analysis Pipeline", variant="primary", size="lg")
1503
+
1504
  with gr.Row():
1505
  status_output = gr.Textbox(label="๐Ÿ“Š Analysis Summary & Statistics", lines=25)
1506
+
1507
  with gr.Tabs():
1508
  with gr.Tab("๐Ÿ“น Annotated Video"):
1509
+ gr.Markdown(
1510
+ "### Full video with player tracking, team colors, ball detection, and events overlay"
1511
+ )
1512
  video_output = gr.Video(label="Processed Video")
1513
+
1514
  with gr.Tab("๐Ÿ“Š Performance Comparison"):
1515
  gr.Markdown("### Interactive charts comparing player performance metrics")
1516
  comparison_output = gr.Plot(label="Team Performance Metrics")
1517
+
1518
  with gr.Tab("๐Ÿ—บ๏ธ Team Heatmaps"):
1519
  gr.Markdown("### Combined activity heatmaps showing team positioning")
1520
  team_heatmaps_output = gr.Image(label="Team Activity Heatmaps")
1521
+
1522
  with gr.Tab("๐Ÿ‘ค Individual Heatmaps"):
1523
  gr.Markdown("### Top 6 players with detailed activity analysis")
1524
  individual_heatmaps_output = gr.Image(label="Top Players Heatmaps")
1525
+
1526
  with gr.Tab("๐ŸŽฎ Game Radar View"):
1527
  gr.Markdown("### Game-style tactical view with ball trail")
1528
  radar_output = gr.Image(label="Tactical Radar View")
 
1531
  gr.Markdown("### Per-player totals: distance, speeds, zones, possession")
1532
  player_stats_output = gr.Dataframe(
1533
  headers=PLAYER_STATS_HEADERS,
1534
+ col_count=len(PLAYER_STATS_HEADERS),
1535
  row_count=0,
1536
+ interactive=False,
1537
  )
1538
 
1539
  with gr.Tab("โฑ๏ธ Event Timeline"):
1540
+ gr.Markdown(
1541
+ "### Detected passes, tackles, interceptions, shots, clearances"
1542
+ )
1543
  events_output = gr.Dataframe(
1544
  headers=EVENT_HEADERS,
1545
+ col_count=len(EVENT_HEADERS),
1546
  row_count=0,
1547
+ interactive=False,
1548
  )
1549
  events_json_output = gr.File(
1550
  label="Download events JSON",
1551
+ file_types=[".json"],
1552
  )
1553
+
1554
  analyze_btn.click(
1555
+ fn=run_pipeline,
1556
  inputs=[video_input],
1557
  outputs=[
1558
+ video_output, # 1
1559
+ comparison_output, # 2
1560
  team_heatmaps_output, # 3
1561
  individual_heatmaps_output, # 4
1562
+ radar_output, # 5
1563
+ status_output, # 6
1564
+ player_stats_output, # 7
1565
+ events_output, # 8
1566
+ events_json_output, # 9
1567
+ ],
1568
  )
1569
+
1570
+ gr.Markdown(
1571
+ """
1572
  ---
1573
  ### ๐Ÿ”ง Technical Details:
1574
 
 
1597
  - Passes, tackles, interceptions, shots, clearances
1598
  - Event banner overlay in video
1599
  - Full event list downloadable as JSON
1600
+ """
1601
+ )
1602
+
1603
 
1604
  if __name__ == "__main__":
1605
+ iface.launch()