Simon9 committed
Commit 1fec87c · verified · Parent(s): 301aff9

Update pipeline_full.py

Files changed (1):
  1. pipeline_full.py +196 -139

pipeline_full.py CHANGED
@@ -2,7 +2,7 @@
 import os
 import base64
 from io import BytesIO
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Any
 from collections import deque, defaultdict
 
 import numpy as np
@@ -10,6 +10,7 @@ import cv2
 import torch
 from more_itertools import chunked
 from PIL import Image
+from tqdm import tqdm
 
 import supervision as sv
 from inference import get_model
@@ -25,77 +26,106 @@ from sports.annotators.soccer import (
     draw_pitch,
     draw_points_on_pitch,
     draw_pitch_voronoi_diagram,
-    draw_paths_on_pitch
+    draw_paths_on_pitch,
 )
 
-# ------------------------------------
-# Global config and models
-# ------------------------------------
+# ------------------------------------------------------------------
+# Globals will be initialized lazily so build/startup doesn't crash
+# ------------------------------------------------------------------
 
-ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY")
-if not ROBOFLOW_API_KEY:
-    raise RuntimeError("ROBOFLOW_API_KEY must be set in Space secrets.")
-
-PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
-FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
+PLAYER_DETECTION_MODEL = None
+FIELD_DETECTION_MODEL = None
+EMBEDDINGS_MODEL = None
+EMBEDDINGS_PROCESSOR = None
+TEAM_CLASSIFIER = None
+PITCH_CONFIG = None
 
 BALL_ID = 0
 GOALKEEPER_ID = 1
 PLAYER_ID = 2
 REFEREE_ID = 3
 
-PLAYER_DETECTION_MODEL = get_model(
-    model_id=PLAYER_DETECTION_MODEL_ID,
-    api_key=ROBOFLOW_API_KEY
-)
-FIELD_DETECTION_MODEL = get_model(
-    model_id=FIELD_DETECTION_MODEL_ID,
-    api_key=ROBOFLOW_API_KEY
-)
-
-SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(DEVICE)
-EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
-
-PITCH_CONFIG = SoccerPitchConfiguration()
-
-TEAM_CLASSIFIER = TeamClassifier(device="cuda")
-
-# ------------------------------------
-# Utility for saving images
-# ------------------------------------
+MODELS_READY = False
+
+
+def ensure_models_loaded():
+    """
+    Lazily load all heavy models and config.
+    Called at the start of run_full_pipeline().
+    """
+    global PLAYER_DETECTION_MODEL, FIELD_DETECTION_MODEL
+    global EMBEDDINGS_MODEL, EMBEDDINGS_PROCESSOR
+    global TEAM_CLASSIFIER, PITCH_CONFIG, MODELS_READY
+
+    if MODELS_READY:
+        return
+
+    roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY")
+    if not roboflow_api_key:
+        raise RuntimeError(
+            "ROBOFLOW_API_KEY env var must be set in the Space secrets "
+            "(Settings → Variables and secrets)."
+        )
+
+    # Roboflow models
+    PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
+    FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
+
+    PLAYER_DETECTION_MODEL = get_model(
+        model_id=PLAYER_DETECTION_MODEL_ID, api_key=roboflow_api_key
+    )
+    FIELD_DETECTION_MODEL = get_model(
+        model_id=FIELD_DETECTION_MODEL_ID, api_key=roboflow_api_key
+    )
+
+    # SigLIP embeddings
+    SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(device)
+    EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
+
+    # Pitch + TeamClassifier
+    PITCH_CONFIG = SoccerPitchConfiguration()
+    TEAM_CLASSIFIER = TeamClassifier(device="cuda" if torch.cuda.is_available() else "cpu")
+
+    MODELS_READY = True
+
+
+def get_device():
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+# -------------------- utility for saving images --------------------
 
 def save_image(path: str, img: np.ndarray) -> None:
     os.makedirs(os.path.dirname(path), exist_ok=True)
-    # supervision uses BGR/ RGB interchangeably; assume RGB here
     if img.ndim == 3 and img.shape[2] == 3:
-        # convert RGB to BGR for cv2
         img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
     else:
         img_bgr = img
     cv2.imwrite(path, img_bgr)
 
-# ------------------------------------
-# 1. Basic frame + detection views
-# ------------------------------------
+
+# -------------------- 1. basic frames & detections --------------------
+
 
 def step_basic_frames(video_path: str, out_dir: str) -> Dict[str, str]:
+    ensure_models_loaded()
+
     frame_generator = sv.get_video_frames_generator(video_path)
     frame = next(frame_generator)
 
-    # Raw frame
     raw_path = os.path.join(out_dir, "frame_raw.png")
     save_image(raw_path, frame)
 
-    # boxes + labels
     box_annotator = sv.BoxAnnotator(
-        color=sv.ColorPalette.from_hex(['#FF8C00', '#00BFFF', '#FF1493', '#FFD700']),
-        thickness=2
+        color=sv.ColorPalette.from_hex(["#FF8C00", "#00BFFF", "#FF1493", "#FFD700"]),
+        thickness=2,
     )
     label_annotator = sv.LabelAnnotator(
-        color=sv.ColorPalette.from_hex(['#FF8C00', '#00BFFF', '#FF1493', '#FFD700']),
-        text_color=sv.Color.from_hex('#000000')
+        color=sv.ColorPalette.from_hex(["#FF8C00", "#00BFFF", "#FF1493", "#FFD700"]),
+        text_color=sv.Color.from_hex("#000000"),
    )
 
     result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
@@ -103,8 +133,7 @@ def step_basic_frames(video_path: str, out_dir: str) -> Dict[str, str]:
 
     labels = [
         f"{class_name} {confidence:.2f}"
-        for class_name, confidence
-        in zip(detections["class_name"], detections.confidence)
+        for class_name, confidence in zip(detections["class_name"], detections.confidence)
     ]
 
     annotated = frame.copy()
@@ -114,16 +143,15 @@ def step_basic_frames(video_path: str, out_dir: str) -> Dict[str, str]:
     boxes_path = os.path.join(out_dir, "frame_boxes_labels.png")
     save_image(boxes_path, annotated)
 
-    # ball vs players using ellipse/triangle
     ellipse_annotator = sv.EllipseAnnotator(
-        color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
-        thickness=2
+        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
+        thickness=2,
    )
     triangle_annotator = sv.TriangleAnnotator(
-        color=sv.Color.from_hex('#FFD700'),
+        color=sv.Color.from_hex("#FFD700"),
         base=25,
         height=21,
-        outline_thickness=1
+        outline_thickness=1,
     )
 
     result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
@@ -149,21 +177,17 @@ def step_basic_frames(video_path: str, out_dir: str) -> Dict[str, str]:
         "ball_players": ball_players_path,
     }
 
-# ------------------------------------
-# 2. SigLIP embeddings + UMAP + KMeans + Plotly HTML
-# ------------------------------------
+
+# -------------------- 2. SigLIP + UMAP + KMeans + HTML --------------------
+
 
 def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
-    SOURCE_VIDEO_PATH = video_path
-    PLAYER_ID = PLAYER_ID
-    STRIDE = 30
+    ensure_models_loaded()
 
-    frame_generator = sv.get_video_frames_generator(
-        source_path=SOURCE_VIDEO_PATH, stride=STRIDE
-    )
+    stride = 30
+    frame_generator = sv.get_video_frames_generator(source_path=video_path, stride=stride)
 
     crops = []
-    from tqdm import tqdm
     for frame in tqdm(frame_generator, desc="collecting crops (SigLIP)"):
         result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
         detections = sv.Detections.from_inference(result)
@@ -180,9 +204,10 @@ def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
     BATCH_SIZE = 32
     batches = chunked(crops_pil, BATCH_SIZE)
     data = []
+    device = get_device()
     with torch.no_grad():
         for batch in tqdm(batches, desc="embedding extraction"):
-            inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(DEVICE)
+            inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(device)
             outputs = EMBEDDINGS_MODEL(**inputs)
             embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
             data.append(embeddings)
@@ -190,27 +215,24 @@ def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
     data = np.concatenate(data)
 
     REDUCER = umap.UMAP(n_components=3)
-    CLUSTERING_MODEL = KMeans(n_clusters=2)
+    CLUSTERING_MODEL = KMeans(n_clusters=2, n_init="auto")
 
     projections = REDUCER.fit_transform(data)
     clusters = CLUSTERING_MODEL.fit_predict(projections)
 
-    # build Plotly 3D + JS same as in notebook
     def pil_image_to_data_uri(image: Image.Image) -> str:
         buffered = BytesIO()
         image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
         return f"data:image/png;base64,{img_str}"
 
-    image_data_uris = {
-        f"image_{i}": pil_image_to_data_uri(image) for i, image in enumerate(crops_pil)
-    }
+    image_data_uris = {f"image_{i}": pil_image_to_data_uri(img) for i, img in enumerate(crops_pil)}
     image_ids = np.array([f"image_{i}" for i in range(len(crops_pil))])
 
     traces = []
     unique_labels = np.unique(clusters)
-    for unique_label in unique_labels:
-        mask = clusters == unique_label
+    for lbl in unique_labels:
+        mask = clusters == lbl
         customdata_masked = image_ids[mask]
         trace = go.Scatter3d(
             x=projections[mask][:, 0],
@@ -219,11 +241,9 @@ def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
             mode="markers+text",
             text=clusters[mask],
             customdata=customdata_masked,
-            name=str(unique_label),
+            name=str(lbl),
             marker=dict(size=8),
-            hovertemplate=(
-                "<b>class: %{text}</b><br>image ID: %{customdata}<extra></extra>"
-            ),
+            hovertemplate="<b>class: %{text}</b><br>image ID: %{customdata}<extra></extra>",
         )
         traces.append(trace)
 
@@ -307,22 +327,22 @@ def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
 </html>
     """
 
+    os.makedirs(out_dir, exist_ok=True)
     html_path = os.path.join(out_dir, "siglip_clusters.html")
     with open(html_path, "w", encoding="utf-8") as f:
         f.write(html_template)
 
     return {"plot_html": html_path}
 
-# ------------------------------------
-# 3. TeamClassifier training (same logic)
-# ------------------------------------
+
+# -------------------- 3. TeamClassifier training --------------------
+
 
 def train_team_classifier_on_video(video_path: str, stride: int = 30) -> None:
-    frame_generator = sv.get_video_frames_generator(
-        source_path=video_path, stride=stride
-    )
+    ensure_models_loaded()
+
+    frame_generator = sv.get_video_frames_generator(source_path=video_path, stride=stride)
     crops = []
-    from tqdm import tqdm
     for frame in tqdm(frame_generator, desc="collecting crops (TeamClassifier)"):
         result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
         detections = sv.Detections.from_inference(result)
@@ -333,13 +353,11 @@ def train_team_classifier_on_video(video_path: str, stride: int = 30) -> None:
     if crops:
         TEAM_CLASSIFIER.fit(crops)
 
-# ------------------------------------
-# 4. resolve_goalkeepers_team_id – your function
-# ------------------------------------
+
+# -------------------- 4. goalkeeper team resolution --------------------
+
 
-def resolve_goalkeepers_team_id(
-    players: sv.Detections, goalkeepers: sv.Detections
-) -> np.ndarray:
+def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
     goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
     players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
     team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
@@ -351,21 +369,79 @@ def resolve_goalkeepers_team_id(
         goalkeepers_team_id.append(0 if dist_0 < dist_1 else 1)
     return np.array(goalkeepers_team_id)
 
-# ------------------------------------
-# 5. One-frame full annotation + radar + Voronoi etc.
-# ------------------------------------
+
+# -------------------- 5. Voronoi blend helper (your function) --------------------
+
+
+def draw_pitch_voronoi_diagram_2(
+    config: SoccerPitchConfiguration,
+    team_1_xy: np.ndarray,
+    team_2_xy: np.ndarray,
+    team_1_color: sv.Color = sv.Color.RED,
+    team_2_color: sv.Color = sv.Color.WHITE,
+    opacity: float = 0.5,
+    padding: int = 50,
+    scale: float = 0.1,
+    pitch: np.ndarray | None = None,
+) -> np.ndarray:
+    if pitch is None:
+        pitch = draw_pitch(config=config, padding=padding, scale=scale)
+
+    scaled_width = int(config.width * scale)
+    scaled_length = int(config.length * scale)
+
+    voronoi = np.zeros_like(pitch, dtype=np.uint8)
+
+    team_1_color_bgr = np.array(team_1_color.as_bgr(), dtype=np.uint8)
+    team_2_color_bgr = np.array(team_2_color.as_bgr(), dtype=np.uint8)
+
+    y_coordinates, x_coordinates = np.indices((scaled_width + 2 * padding, scaled_length + 2 * padding))
+    y_coordinates -= padding
+    x_coordinates -= padding
+
+    def calculate_distances(xy, x_coordinates, y_coordinates):
+        return np.sqrt(
+            (xy[:, 0][:, None, None] * scale - x_coordinates) ** 2
+            + (xy[:, 1][:, None, None] * scale - y_coordinates) ** 2
+        )
+
+    distances_team_1 = calculate_distances(team_1_xy, x_coordinates, y_coordinates)
+    distances_team_2 = calculate_distances(team_2_xy, x_coordinates, y_coordinates)
+
+    min_distances_team_1 = np.min(distances_team_1, axis=0)
+    min_distances_team_2 = np.min(distances_team_2, axis=0)
+
+    steepness = 15
+    distance_ratio = min_distances_team_2 / np.clip(
+        min_distances_team_1 + min_distances_team_2, a_min=1e-5, a_max=None
+    )
+    blend_factor = np.tanh((distance_ratio - 0.5) * steepness) * 0.5 + 0.5
+
+    for c in range(3):
+        voronoi[:, :, c] = (
+            blend_factor * team_1_color_bgr[c] + (1 - blend_factor) * team_2_color_bgr[c]
+        ).astype(np.uint8)
+
+    overlay = cv2.addWeighted(voronoi, opacity, pitch, 1 - opacity, 0)
+    return overlay
+
+
+# -------------------- 6. single-frame advanced views --------------------
+
 
 def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
+    ensure_models_loaded()
+
     frame_generator = sv.get_video_frames_generator(video_path)
     frame = next(frame_generator)
 
     ellipse_annotator = sv.EllipseAnnotator(
-        color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
+        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
         thickness=2,
     )
     label_annotator = sv.LabelAnnotator(
-        color=sv.ColorPalette.from_hex(['#00BFFF', '#FF1493', '#FFD700']),
-        text_color=sv.Color.from_hex('#000000'),
+        color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
+        text_color=sv.Color.from_hex("#000000"),
         text_position=sv.Position.BOTTOM_CENTER,
     )
     triangle_annotator = sv.TriangleAnnotator(
@@ -375,7 +451,6 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     tracker = sv.ByteTrack()
     tracker.reset()
 
-    # detect ball, goalkeeper, player, referee
     result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
     detections = sv.Detections.from_inference(result)
 
@@ -391,9 +466,10 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     referees_detections = all_detections[all_detections.class_id == REFEREE_ID]
 
     players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
-    players_detections.class_id = TEAM_CLASSIFIER.predict(players_crops)
+    if players_crops:
+        players_detections.class_id = TEAM_CLASSIFIER.predict(players_crops)
 
-    if len(goalkeepers_detections) > 0:
+    if len(goalkeepers_detections) > 0 and len(players_detections) > 0:
         goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
             players_detections, goalkeepers_detections
         )
@@ -408,9 +484,7 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     all_detections2.class_id = all_detections2.class_id.astype(int)
 
     annotated_frame = frame.copy()
-    annotated_frame = ellipse_annotator.annotate(
-        scene=annotated_frame, detections=all_detections2
-    )
+    annotated_frame = ellipse_annotator.annotate(scene=annotated_frame, detections=all_detections2)
     annotated_frame = label_annotator.annotate(
         scene=annotated_frame, detections=all_detections2, labels=labels
     )
@@ -418,10 +492,11 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
         scene=annotated_frame, detections=ball_detections
     )
 
+    os.makedirs(out_dir, exist_ok=True)
     annotated_path = os.path.join(out_dir, "frame_advanced.png")
     save_image(annotated_path, annotated_frame)
 
-    # Pitch keypoints, radar, Voronoi, etc. – same as notebook logic
+    # Pitch + radar + Voronoi
     result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
     key_points = sv.KeyPoints.from_inference(result)
 
@@ -429,9 +504,7 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     frame_reference_points = key_points.xy[0][filt]
     pitch_reference_points = np.array(PITCH_CONFIG.vertices)[filt]
 
-    transformer = ViewTransformer(
-        source=frame_reference_points, target=pitch_reference_points
-    )
+    transformer = ViewTransformer(source=frame_reference_points, target=pitch_reference_points)
 
     frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
     pitch_ball_xy = transformer.transform_points(points=frame_ball_xy)
@@ -442,7 +515,6 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     referees_xy = referees_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
     pitch_referees_xy = transformer.transform_points(points=referees_xy)
 
-    # radar view
     radar = draw_pitch(PITCH_CONFIG)
     radar = draw_points_on_pitch(
         config=PITCH_CONFIG,
@@ -479,7 +551,6 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     radar_path = os.path.join(out_dir, "radar_view.png")
     save_image(radar_path, radar)
 
-    # Voronoi classic
     vor = draw_pitch(PITCH_CONFIG)
     vor = draw_pitch_voronoi_diagram(
         config=PITCH_CONFIG,
@@ -492,11 +563,8 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
     vor_path = os.path.join(out_dir, "voronoi.png")
     save_image(vor_path, vor)
 
-    # Blended Voronoi (your custom function)
     blended = draw_pitch(
-        config=PITCH_CONFIG,
-        background_color=sv.Color.WHITE,
-        line_color=sv.Color.BLACK,
+        config=PITCH_CONFIG, background_color=sv.Color.WHITE, line_color=sv.Color.BLACK
     )
     blended = draw_pitch_voronoi_diagram_2(
         config=PITCH_CONFIG,
@@ -543,16 +611,12 @@ def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
         "voronoi_blended": blended_path,
     }
 
-# ------------------------------------
-# 6. Ball path & outlier cleaning – same logic
-# ------------------------------------
+
+# -------------------- 7. ball path & cleaning --------------------
 
-def replace_outliers_based_on_distance(
-    positions: List[np.ndarray], distance_threshold: float
-) -> List[np.ndarray]:
-    from typing import Union
 
-    last_valid_position: Union[np.ndarray, None] = None
+def replace_outliers_based_on_distance(positions: List[np.ndarray], distance_threshold: float) -> List[np.ndarray]:
+    last_valid_position = None
     cleaned_positions: List[np.ndarray] = []
 
     for position in positions:
@@ -572,7 +636,10 @@ def replace_outliers_based_on_distance(
 
     return cleaned_positions
 
+
 def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
+    ensure_models_loaded()
+
     MAXLEN = 5
     MAX_DISTANCE_THRESHOLD = 500
 
@@ -582,8 +649,7 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
     path_raw: List[np.ndarray] = []
     M = deque(maxlen=MAXLEN)
 
-    from tqdm import tqdm
-    for frame in tqdm(frame_generator, total=video_info.total_frames):
+    for frame in tqdm(frame_generator, total=video_info.total_frames, desc="ball path"):
         result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
         detections = sv.Detections.from_inference(result)
 
@@ -603,9 +669,7 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
         M.append(transformer.m)
         transformer.m = np.mean(np.array(M), axis=0)
 
-        frame_ball_xy = ball_detections.get_anchors_coordinates(
-            sv.Position.BOTTOM_CENTER
-        )
+        frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
         pitch_ball_xy = transformer.transform_points(points=frame_ball_xy)
 
         path_raw.append(pitch_ball_xy)
@@ -618,7 +682,6 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
 
     path_clean = replace_outliers_based_on_distance(path, MAX_DISTANCE_THRESHOLD)
 
-    # draw raw
     raw_pitch = draw_pitch(PITCH_CONFIG)
     raw_pitch = draw_paths_on_pitch(
         config=PITCH_CONFIG, paths=[path], color=sv.Color.WHITE, pitch=raw_pitch
@@ -626,7 +689,6 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
     raw_path_img = os.path.join(out_dir, "ball_path_raw.png")
    save_image(raw_path_img, raw_pitch)
 
-    # draw cleaned
     clean_pitch = draw_pitch(PITCH_CONFIG)
     clean_pitch = draw_paths_on_pitch(
         config=PITCH_CONFIG, paths=[path_clean], color=sv.Color.WHITE, pitch=clean_pitch
@@ -634,7 +696,6 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
     cleaned_path_img = os.path.join(out_dir, "ball_path_cleaned.png")
     save_image(cleaned_path_img, clean_pitch)
 
-    # return coords as simple list for JSON
     coords_clean = [
         coords.tolist() if len(coords) > 0 else [] for coords in path_clean
     ]
@@ -645,11 +706,13 @@ def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
         "ball_path_cleaned_coords": coords_clean,
     }
 
-# ------------------------------------
-# 7. Stats-only process_video (like your FastAPI helper)
-# ------------------------------------
+
+# -------------------- 8. stats-only process_video --------------------
+
 
 def process_video_stats(video_path: str) -> Dict[str, Any]:
+    ensure_models_loaded()
+
     tracker = sv.ByteTrack()
     tracker.reset()
     stats = {
@@ -675,29 +738,24 @@ def process_video_stats(video_path: str) -> Dict[str, Any]:
     stats["distance_covered"] = dict(stats["distance_covered"])
     return stats
 
-# ------------------------------------
-# 8. Entry point: run full pipeline on a video
-# ------------------------------------
+
+# -------------------- 9. full pipeline entrypoint --------------------
+
 
 def run_full_pipeline(video_path: str, job_dir: str) -> Dict[str, Any]:
+    """
+    Run the full notebook-equivalent pipeline on a video and save all artifacts
+    into job_dir. Returns paths + stats for the FastAPI app.
+    """
+    ensure_models_loaded()
     os.makedirs(job_dir, exist_ok=True)
 
-    # 1) SigLIP & TeamClassifier training as in NB
-    step_siglip_clustering(video_path, os.path.join(job_dir, "siglip"))
+    siglip_out = step_siglip_clustering(video_path, os.path.join(job_dir, "siglip"))
     train_team_classifier_on_video(video_path)
 
-    # 2) Basic visualizations
     basic_paths = step_basic_frames(video_path, os.path.join(job_dir, "frames"))
-
-    # 3) Advanced one-frame analytics (radar, Voronoi, etc.)
-    adv_paths = step_single_frame_advanced(
-        video_path, os.path.join(job_dir, "advanced")
-    )
-
-    # 4) Ball path & heatmap
+    adv_paths = step_single_frame_advanced(video_path, os.path.join(job_dir, "advanced"))
     ball_paths = step_ball_path(video_path, os.path.join(job_dir, "ball_path"))
-
-    # 5) Stats
     stats = process_video_stats(video_path)
 
     return {
@@ -705,6 +763,5 @@ def run_full_pipeline(video_path: str, job_dir: str) -> Dict[str, Any]:
         "advanced": adv_paths,
         "ball": ball_paths,
         "stats": stats,
-        # SigLIP HTML path known: job_dir/siglip/siglip_clusters.html
-        "siglip_html": os.path.join(job_dir, "siglip", "siglip_clusters.html"),
+        "siglip_html": siglip_out["plot_html"],
     }
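
The core change in this commit is that nothing heavy is constructed at import time anymore: models, the SigLIP processor, the pitch config, and the team classifier are created on first use via ensure_models_loaded(). A minimal sketch of how a caller might drive the updated module, assuming pipeline_full.py is importable and ROBOFLOW_API_KEY is set; the handler name, job-id scheme, and "jobs/" directory below are illustrative and not part of the commit:

```python
# Hypothetical caller sketch, not part of pipeline_full.py.
# Assumes: pipeline_full is on PYTHONPATH and ROBOFLOW_API_KEY is configured.
import os
import uuid

import pipeline_full  # cheap to import now; models load lazily on first use


def handle_upload(video_path: str) -> dict:
    # run_full_pipeline() calls ensure_models_loaded() internally, so the first
    # request pays the model-loading cost and later requests reuse the globals.
    job_dir = os.path.join("jobs", uuid.uuid4().hex)  # hypothetical layout
    result = pipeline_full.run_full_pipeline(video_path, job_dir)
    # Per the diff, result includes at least "advanced", "ball", "stats",
    # and "siglip_html"; the values are artifact paths plus a stats dict.
    return result


if __name__ == "__main__":
    print(handle_upload("sample_match.mp4"))  # placeholder video path
```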