# pipeline_full.py
import os
import json
import base64
from io import BytesIO
from typing import List, Dict, Any, Optional
from collections import deque, defaultdict
# Silence optional-model warnings from `inference`
os.environ["CORE_MODEL_SAM_ENABLED"] = "False"
os.environ["CORE_MODEL_SAM2_ENABLED"] = "False"
os.environ["CORE_MODEL_SAM3_ENABLED"] = "False"
os.environ["CORE_MODEL_GAZE_ENABLED"] = "False"
os.environ["CORE_MODEL_GROUNDINGDINO_ENABLED"] = "False"
os.environ["CORE_MODEL_YOLO_WORLD_ENABLED"] = "False"
import numpy as np
import cv2
import torch
from more_itertools import chunked
from PIL import Image
from tqdm import tqdm
import supervision as sv
from inference import get_model
from transformers import AutoProcessor, SiglipVisionModel
import umap
from sklearn.cluster import KMeans
import plotly.graph_objects as go
from sports.common.team import TeamClassifier
from sports.common.view import ViewTransformer
from sports.configs.soccer import SoccerPitchConfiguration
from sports.annotators.soccer import (
draw_pitch,
draw_points_on_pitch,
draw_pitch_voronoi_diagram,
draw_paths_on_pitch,
)
# ------------------------------------------------------------------
# Globals – initialized lazily so build/startup doesn't crash
# ------------------------------------------------------------------
PLAYER_DETECTION_MODEL = None
FIELD_DETECTION_MODEL = None
EMBEDDINGS_MODEL = None
EMBEDDINGS_PROCESSOR = None
TEAM_CLASSIFIER = None
PITCH_CONFIG = None
# Class ids of the Roboflow football-players-detection model
BALL_ID = 0
GOALKEEPER_ID = 1
PLAYER_ID = 2
REFEREE_ID = 3
MODELS_READY = False
# progress tracking
CURRENT_JOB_DIR: Optional[str] = None
def set_job_dir(job_dir: str):
global CURRENT_JOB_DIR
CURRENT_JOB_DIR = job_dir
def update_progress(stage: str, progress: float, message: str = ""):
"""
Write a small JSON status file in the current job dir so the UI can poll.
"""
if not CURRENT_JOB_DIR:
return
status = {
"stage": stage,
"progress": float(progress),
"message": message,
}
os.makedirs(CURRENT_JOB_DIR, exist_ok=True)
status_path = os.path.join(CURRENT_JOB_DIR, "status.json")
with open(status_path, "w", encoding="utf-8") as f:
json.dump(status, f)
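# A minimal sketch (not part of the pipeline) of how a UI client might poll
# the status file written above; `wait_until_done` is a hypothetical helper.
#
#     import json, time
#
#     def wait_until_done(job_dir: str, poll_s: float = 1.0) -> dict:
#         while True:
#             with open(os.path.join(job_dir, "status.json"), encoding="utf-8") as f:
#                 status = json.load(f)
#             if status["progress"] >= 1.0:
#                 return status
#             time.sleep(poll_s)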
def ensure_models_loaded():
"""
Lazily load all heavy models and config.
Called at the start of run_full_pipeline().
"""
global PLAYER_DETECTION_MODEL, FIELD_DETECTION_MODEL
global EMBEDDINGS_MODEL, EMBEDDINGS_PROCESSOR
global TEAM_CLASSIFIER, PITCH_CONFIG, MODELS_READY
if MODELS_READY:
return
roboflow_api_key = os.environ.get("ROBOFLOW_API_KEY")
if not roboflow_api_key:
raise RuntimeError(
"ROBOFLOW_API_KEY env var must be set in the Space secrets "
"(Settings → Variables and secrets)."
)
# Roboflow models
PLAYER_DETECTION_MODEL_ID = "football-players-detection-3zvbc/11"
FIELD_DETECTION_MODEL_ID = "football-field-detection-f07vi/14"
PLAYER_DETECTION_MODEL = get_model(
model_id=PLAYER_DETECTION_MODEL_ID, api_key=roboflow_api_key
)
FIELD_DETECTION_MODEL = get_model(
model_id=FIELD_DETECTION_MODEL_ID, api_key=roboflow_api_key
)
# SigLIP embeddings
SIGLIP_MODEL_PATH = "google/siglip-base-patch16-224"
device = get_device()
EMBEDDINGS_MODEL = SiglipVisionModel.from_pretrained(SIGLIP_MODEL_PATH).to(device)
EMBEDDINGS_PROCESSOR = AutoProcessor.from_pretrained(SIGLIP_MODEL_PATH)
# Pitch + TeamClassifier
PITCH_CONFIG = SoccerPitchConfiguration()
    TEAM_CLASSIFIER = TeamClassifier(device=device)  # reuse the device resolved above
MODELS_READY = True
def get_device():
return "cuda" if torch.cuda.is_available() else "cpu"
# -------------------- utility for saving images --------------------
def save_image(path: str, img: np.ndarray) -> None:
    """Write an image to disk, creating the parent directory if needed.

    Frames from `supervision`/OpenCV and the pitch renderers are already BGR,
    so they are written directly; converting RGB->BGR here would swap the red
    and blue channels in every saved artifact.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    cv2.imwrite(path, img)
# -------------------- 1. basic frames & detections --------------------
def step_basic_frames(video_path: str, out_dir: str) -> Dict[str, str]:
ensure_models_loaded()
frame_generator = sv.get_video_frames_generator(video_path)
frame = next(frame_generator)
raw_path = os.path.join(out_dir, "frame_raw.png")
save_image(raw_path, frame)
box_annotator = sv.BoxAnnotator(
color=sv.ColorPalette.from_hex(["#FF8C00", "#00BFFF", "#FF1493", "#FFD700"]),
thickness=2,
)
label_annotator = sv.LabelAnnotator(
color=sv.ColorPalette.from_hex(["#FF8C00", "#00BFFF", "#FF1493", "#FFD700"]),
text_color=sv.Color.from_hex("#000000"),
)
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(result)
labels = [
f"{class_name} {confidence:.2f}"
for class_name, confidence in zip(detections["class_name"], detections.confidence)
]
annotated = frame.copy()
annotated = box_annotator.annotate(scene=annotated, detections=detections)
annotated = label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
boxes_path = os.path.join(out_dir, "frame_boxes_labels.png")
save_image(boxes_path, annotated)
ellipse_annotator = sv.EllipseAnnotator(
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
thickness=2,
)
triangle_annotator = sv.TriangleAnnotator(
color=sv.Color.from_hex("#FFD700"),
base=25,
height=21,
outline_thickness=1,
)
    # Reuse the detections from the inference call above; re-running the model
    # on the same frame is unnecessary.
ball_detections = detections[detections.class_id == BALL_ID]
ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)
all_detections = detections[detections.class_id != BALL_ID]
all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
    all_detections.class_id -= 1  # shift GK/player/referee ids (1, 2, 3) -> palette indices (0, 1, 2)
annotated2 = frame.copy()
annotated2 = ellipse_annotator.annotate(scene=annotated2, detections=all_detections)
annotated2 = triangle_annotator.annotate(scene=annotated2, detections=ball_detections)
ball_players_path = os.path.join(out_dir, "frame_ball_players.png")
save_image(ball_players_path, annotated2)
return {
"raw_frame": raw_path,
"boxes_labels": boxes_path,
"ball_players": ball_players_path,
}
# -------------------- 2. SigLIP + UMAP + KMeans + HTML --------------------
def step_siglip_clustering(video_path: str, out_dir: str) -> Dict[str, str]:
ensure_models_loaded()
stride = 30
frame_generator = sv.get_video_frames_generator(source_path=video_path, stride=stride)
crops = []
for frame in tqdm(frame_generator, desc="collecting crops (SigLIP)"):
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(result)
detections = detections.with_nms(threshold=0.5, class_agnostic=True)
detections = detections[detections.class_id == PLAYER_ID]
players_crops = [sv.crop_image(frame, xyxy) for xyxy in detections.xyxy]
crops += players_crops
if not crops:
return {"plot_html": ""}
crops_pil = [sv.cv2_to_pillow(c) for c in crops]
BATCH_SIZE = 32
batches = chunked(crops_pil, BATCH_SIZE)
data = []
device = get_device()
with torch.no_grad():
for batch in tqdm(batches, desc="embedding extraction"):
inputs = EMBEDDINGS_PROCESSOR(images=batch, return_tensors="pt").to(device)
outputs = EMBEDDINGS_MODEL(**inputs)
embeddings = torch.mean(outputs.last_hidden_state, dim=1).cpu().numpy()
data.append(embeddings)
data = np.concatenate(data)
REDUCER = umap.UMAP(n_components=3)
CLUSTERING_MODEL = KMeans(n_clusters=2, n_init="auto")
projections = REDUCER.fit_transform(data)
clusters = CLUSTERING_MODEL.fit_predict(projections)
def pil_image_to_data_uri(image: Image.Image) -> str:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return f"data:image/png;base64,{img_str}"
image_data_uris = {f"image_{i}": pil_image_to_data_uri(img) for i, img in enumerate(crops_pil)}
image_ids = np.array([f"image_{i}" for i in range(len(crops_pil))])
traces = []
unique_labels = np.unique(clusters)
for lbl in unique_labels:
mask = clusters == lbl
customdata_masked = image_ids[mask]
trace = go.Scatter3d(
x=projections[mask][:, 0],
y=projections[mask][:, 1],
z=projections[mask][:, 2],
mode="markers+text",
text=clusters[mask],
customdata=customdata_masked,
name=str(lbl),
marker=dict(size=8),
hovertemplate="class: %{text}
image ID: %{customdata}",
)
traces.append(trace)
min_val = np.min(projections)
max_val = np.max(projections)
padding = (max_val - min_val) * 0.05
axis_range = [min_val - padding, max_val + padding]
fig = go.Figure(data=traces)
fig.update_layout(
scene=dict(
xaxis=dict(title="X", range=axis_range),
yaxis=dict(title="Y", range=axis_range),
zaxis=dict(title="Z", range=axis_range),
aspectmode="cube",
),
width=1000,
height=1000,
showlegend=False,
)
plotly_div = fig.to_html(full_html=False, include_plotlyjs=False, div_id="scatter-plot-3d")
javascript_code = f"""
"""
html_template = f"""
{plotly_div}
Click on a data entry to display an image
{javascript_code}
"""
os.makedirs(out_dir, exist_ok=True)
html_path = os.path.join(out_dir, "siglip_clusters.html")
with open(html_path, "w", encoding="utf-8") as f:
f.write(html_template)
return {"plot_html": html_path}
# -------------------- 3. TeamClassifier training --------------------
def train_team_classifier_on_video(video_path: str, stride: int = 30) -> None:
ensure_models_loaded()
frame_generator = sv.get_video_frames_generator(source_path=video_path, stride=stride)
crops = []
for frame in tqdm(frame_generator, desc="collecting crops (TeamClassifier)"):
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(result)
players_detections = detections[detections.class_id == PLAYER_ID]
players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
crops += players_crops
if crops:
TEAM_CLASSIFIER.fit(crops)
# -------------------- 4. goalkeeper team resolution --------------------
def resolve_goalkeepers_team_id(players: sv.Detections, goalkeepers: sv.Detections) -> np.ndarray:
    """Assign each goalkeeper to the team whose player centroid is closer.

    Jersey-based classification is unreliable for goalkeepers (different kit),
    so proximity to each team's positional centroid is used instead.
    """
    if len(goalkeepers) == 0 or len(players) == 0:
        return np.array([])
goalkeepers_xy = goalkeepers.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
players_xy = players.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
team_0_centroid = players_xy[players.class_id == 0].mean(axis=0)
team_1_centroid = players_xy[players.class_id == 1].mean(axis=0)
goalkeepers_team_id = []
for goalkeeper_xy in goalkeepers_xy:
dist_0 = np.linalg.norm(goalkeeper_xy - team_0_centroid)
dist_1 = np.linalg.norm(goalkeeper_xy - team_1_centroid)
goalkeepers_team_id.append(0 if dist_0 < dist_1 else 1)
return np.array(goalkeepers_team_id)
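# Worked example (hypothetical image-space coordinates): with team-0 players'
# bottom-centers averaging (300, 410) and team-1 players' averaging (1500, 410),
# a goalkeeper at (100, 410) is nearer the team-0 centroid and gets team 0,
# while one at (1700, 410) gets team 1.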
# -------------------- 5. Voronoi blend helper --------------------
def draw_pitch_voronoi_diagram_2(
config: SoccerPitchConfiguration,
team_1_xy: np.ndarray,
team_2_xy: np.ndarray,
team_1_color: sv.Color = sv.Color.RED,
team_2_color: sv.Color = sv.Color.WHITE,
opacity: float = 0.5,
padding: int = 50,
scale: float = 0.1,
pitch: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Soft Voronoi overlay: rather than a hard nearest-player partition, the
    two team colors are blended with a tanh ramp on each pixel's relative
    distance to the nearest player of either team, then alpha-blended onto
    the pitch with `opacity`.
    """
    if pitch is None:
        pitch = draw_pitch(config=config, padding=padding, scale=scale)
scaled_width = int(config.width * scale)
scaled_length = int(config.length * scale)
voronoi = np.zeros_like(pitch, dtype=np.uint8)
team_1_color_bgr = np.array(team_1_color.as_bgr(), dtype=np.uint8)
team_2_color_bgr = np.array(team_2_color.as_bgr(), dtype=np.uint8)
y_coordinates, x_coordinates = np.indices((scaled_width + 2 * padding, scaled_length + 2 * padding))
y_coordinates -= padding
x_coordinates -= padding
def calculate_distances(xy, x_coordinates, y_coordinates):
return np.sqrt(
(xy[:, 0][:, None, None] * scale - x_coordinates) ** 2
+ (xy[:, 1][:, None, None] * scale - y_coordinates) ** 2
)
distances_team_1 = calculate_distances(team_1_xy, x_coordinates, y_coordinates)
distances_team_2 = calculate_distances(team_2_xy, x_coordinates, y_coordinates)
min_distances_team_1 = np.min(distances_team_1, axis=0)
min_distances_team_2 = np.min(distances_team_2, axis=0)
steepness = 15
distance_ratio = min_distances_team_2 / np.clip(
min_distances_team_1 + min_distances_team_2, a_min=1e-5, a_max=None
)
blend_factor = np.tanh((distance_ratio - 0.5) * steepness) * 0.5 + 0.5
for c in range(3):
voronoi[:, :, c] = (
blend_factor * team_1_color_bgr[c] + (1 - blend_factor) * team_2_color_bgr[c]
).astype(np.uint8)
overlay = cv2.addWeighted(voronoi, opacity, pitch, 1 - opacity, 0)
return overlay
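# Usage sketch with hypothetical player positions, given in the pitch's native
# centimeter coordinates (length 12000 cm x width 7000 cm by default):
#
#     team_1_xy = np.array([[2000.0, 1000.0], [4000.0, 3500.0]])
#     team_2_xy = np.array([[8000.0, 3500.0], [10000.0, 6000.0]])
#     img = draw_pitch_voronoi_diagram_2(
#         config=SoccerPitchConfiguration(),
#         team_1_xy=team_1_xy,
#         team_2_xy=team_2_xy,
#     )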
# -------------------- 6. single-frame advanced views --------------------
def step_single_frame_advanced(video_path: str, out_dir: str) -> Dict[str, str]:
ensure_models_loaded()
frame_generator = sv.get_video_frames_generator(video_path)
frame = next(frame_generator)
ellipse_annotator = sv.EllipseAnnotator(
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
thickness=2,
)
label_annotator = sv.LabelAnnotator(
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
text_color=sv.Color.from_hex("#000000"),
text_position=sv.Position.BOTTOM_CENTER,
)
triangle_annotator = sv.TriangleAnnotator(
color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
)
tracker = sv.ByteTrack()
tracker.reset()
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(result)
ball_detections = detections[detections.class_id == BALL_ID]
ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)
all_detections = detections[detections.class_id != BALL_ID]
all_detections = all_detections.with_nms(threshold=0.5, class_agnostic=True)
all_detections = tracker.update_with_detections(detections=all_detections)
goalkeepers_detections = all_detections[all_detections.class_id == GOALKEEPER_ID]
players_detections = all_detections[all_detections.class_id == PLAYER_ID]
referees_detections = all_detections[all_detections.class_id == REFEREE_ID]
players_crops = [sv.crop_image(frame, xyxy) for xyxy in players_detections.xyxy]
if players_crops:
players_detections.class_id = TEAM_CLASSIFIER.predict(players_crops)
if len(goalkeepers_detections) > 0 and len(players_detections) > 0:
goalkeepers_detections.class_id = resolve_goalkeepers_team_id(
players_detections, goalkeepers_detections
)
    referees_detections.class_id -= 1  # referee id 3 -> palette index 2
all_detections2 = sv.Detections.merge(
[players_detections, goalkeepers_detections, referees_detections]
)
labels = [f"#{tid}" for tid in all_detections2.tracker_id]
all_detections2.class_id = all_detections2.class_id.astype(int)
annotated_frame = frame.copy()
annotated_frame = ellipse_annotator.annotate(scene=annotated_frame, detections=all_detections2)
annotated_frame = label_annotator.annotate(
scene=annotated_frame, detections=all_detections2, labels=labels
)
annotated_frame = triangle_annotator.annotate(
scene=annotated_frame, detections=ball_detections
)
os.makedirs(out_dir, exist_ok=True)
annotated_path = os.path.join(out_dir, "frame_advanced.png")
save_image(annotated_path, annotated_frame)
# Pitch + radar + Voronoi
result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
key_points = sv.KeyPoints.from_inference(result)
filt = key_points.confidence[0] > 0.5
frame_reference_points = key_points.xy[0][filt]
pitch_reference_points = np.array(PITCH_CONFIG.vertices)[filt]
transformer = ViewTransformer(source=frame_reference_points, target=pitch_reference_points)
frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
pitch_ball_xy = transformer.transform_points(points=frame_ball_xy)
players_xy = players_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
pitch_players_xy = transformer.transform_points(points=players_xy)
referees_xy = referees_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
pitch_referees_xy = transformer.transform_points(points=referees_xy)
radar = draw_pitch(PITCH_CONFIG)
radar = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_ball_xy,
face_color=sv.Color.WHITE,
edge_color=sv.Color.BLACK,
radius=10,
pitch=radar,
)
radar = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_players_xy[players_detections.class_id == 0],
face_color=sv.Color.from_hex("00BFFF"),
edge_color=sv.Color.BLACK,
radius=16,
pitch=radar,
)
radar = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_players_xy[players_detections.class_id == 1],
face_color=sv.Color.from_hex("FF1493"),
edge_color=sv.Color.BLACK,
radius=16,
pitch=radar,
)
radar = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_referees_xy,
face_color=sv.Color.from_hex("FFD700"),
edge_color=sv.Color.BLACK,
radius=16,
pitch=radar,
)
radar_path = os.path.join(out_dir, "radar_view.png")
save_image(radar_path, radar)
vor = draw_pitch(PITCH_CONFIG)
vor = draw_pitch_voronoi_diagram(
config=PITCH_CONFIG,
team_1_xy=pitch_players_xy[players_detections.class_id == 0],
team_2_xy=pitch_players_xy[players_detections.class_id == 1],
team_1_color=sv.Color.from_hex("00BFFF"),
team_2_color=sv.Color.from_hex("FF1493"),
pitch=vor,
)
vor_path = os.path.join(out_dir, "voronoi.png")
save_image(vor_path, vor)
blended = draw_pitch(
config=PITCH_CONFIG, background_color=sv.Color.WHITE, line_color=sv.Color.BLACK
)
blended = draw_pitch_voronoi_diagram_2(
config=PITCH_CONFIG,
team_1_xy=pitch_players_xy[players_detections.class_id == 0],
team_2_xy=pitch_players_xy[players_detections.class_id == 1],
team_1_color=sv.Color.from_hex("00BFFF"),
team_2_color=sv.Color.from_hex("FF1493"),
pitch=blended,
)
blended = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_ball_xy,
face_color=sv.Color.WHITE,
edge_color=sv.Color.WHITE,
radius=8,
thickness=1,
pitch=blended,
)
blended = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_players_xy[players_detections.class_id == 0],
face_color=sv.Color.from_hex("00BFFF"),
edge_color=sv.Color.WHITE,
radius=16,
thickness=1,
pitch=blended,
)
blended = draw_points_on_pitch(
config=PITCH_CONFIG,
xy=pitch_players_xy[players_detections.class_id == 1],
face_color=sv.Color.from_hex("FF1493"),
edge_color=sv.Color.WHITE,
radius=16,
thickness=1,
pitch=blended,
)
blended_path = os.path.join(out_dir, "voronoi_blended.png")
save_image(blended_path, blended)
return {
"frame_advanced": annotated_path,
"radar": radar_path,
"voronoi": vor_path,
"voronoi_blended": blended_path,
}
# -------------------- 7. ball path & cleaning --------------------
def replace_outliers_based_on_distance(positions: List[np.ndarray], distance_threshold: float) -> List[np.ndarray]:
    """Replace any position that jumps farther than `distance_threshold` from
    the last valid position with an empty array, preserving list length.
    Empty entries (frames without a detection) pass through unchanged.
    """
    last_valid_position = None
    cleaned_positions: List[np.ndarray] = []
for position in positions:
if len(position) == 0:
cleaned_positions.append(position)
else:
if last_valid_position is None:
cleaned_positions.append(position)
last_valid_position = position
else:
distance = np.linalg.norm(position - last_valid_position)
if distance > distance_threshold:
cleaned_positions.append(np.array([], dtype=np.float64))
else:
cleaned_positions.append(position)
last_valid_position = position
return cleaned_positions
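# Toy example: with distance_threshold=500 (cm), the sudden jump to
# (9000, 9000) is discarded while the empty no-detection entry is kept:
#
#     pts = [np.array([100.0, 100.0]), np.array([]),
#            np.array([9000.0, 9000.0]), np.array([150.0, 120.0])]
#     replace_outliers_based_on_distance(pts, 500)
#     # -> [[100, 100], [], [], [150, 120]]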
def step_ball_path(video_path: str, out_dir: str) -> Dict[str, Any]:
ensure_models_loaded()
    MAXLEN = 5  # number of homography matrices to average for smoothing
    MAX_DISTANCE_THRESHOLD = 500  # pitch units are centimeters, so 500 == 5 m
video_info = sv.VideoInfo.from_video_path(video_path)
frame_generator = sv.get_video_frames_generator(video_path)
path_raw: List[np.ndarray] = []
M = deque(maxlen=MAXLEN)
for frame in tqdm(frame_generator, total=video_info.total_frames, desc="ball path"):
result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(result)
ball_detections = detections[detections.class_id == BALL_ID]
ball_detections.xyxy = sv.pad_boxes(xyxy=ball_detections.xyxy, px=10)
result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
key_points = sv.KeyPoints.from_inference(result)
filt = key_points.confidence[0] > 0.5
frame_reference_points = key_points.xy[0][filt]
pitch_reference_points = np.array(PITCH_CONFIG.vertices)[filt]
transformer = ViewTransformer(
source=frame_reference_points, target=pitch_reference_points
)
M.append(transformer.m)
transformer.m = np.mean(np.array(M), axis=0)
frame_ball_xy = ball_detections.get_anchors_coordinates(sv.Position.BOTTOM_CENTER)
pitch_ball_xy = transformer.transform_points(points=frame_ball_xy)
path_raw.append(pitch_ball_xy)
    # Keep frames with exactly one ball detection; frames with two or more
    # candidates are ambiguous and replaced by an empty array.
    path = [
        np.empty((0, 2), dtype=np.float32) if coords.shape[0] >= 2 else coords
        for coords in path_raw
    ]
    path = [coords.flatten() for coords in path]
path_clean = replace_outliers_based_on_distance(path, MAX_DISTANCE_THRESHOLD)
raw_pitch = draw_pitch(PITCH_CONFIG)
raw_pitch = draw_paths_on_pitch(
config=PITCH_CONFIG, paths=[path], color=sv.Color.WHITE, pitch=raw_pitch
)
raw_path_img = os.path.join(out_dir, "ball_path_raw.png")
save_image(raw_path_img, raw_pitch)
clean_pitch = draw_pitch(PITCH_CONFIG)
clean_pitch = draw_paths_on_pitch(
config=PITCH_CONFIG, paths=[path_clean], color=sv.Color.WHITE, pitch=clean_pitch
)
cleaned_path_img = os.path.join(out_dir, "ball_path_cleaned.png")
save_image(cleaned_path_img, clean_pitch)
coords_clean = [
coords.tolist() if len(coords) > 0 else [] for coords in path_clean
]
return {
"ball_path_raw_img": raw_path_img,
"ball_path_cleaned_img": cleaned_path_img,
"ball_path_cleaned_coords": coords_clean,
}
# -------------------- 8. full-match analysis + event-annotated video --------------------
def step_analyze_and_annotate_video(video_path: str, out_dir: str) -> Dict[str, Any]:
"""
Single pass over the video that:
* tracks players & ball
* computes distance & speed per player (pitch coordinates)
* estimates ball possession per team & per player
* estimates time spent in defensive/middle/attacking thirds
* detects simple events:
- passes (successful between teammates)
- tackles / interceptions (winning ball from opponent)
- clearances
- shots (high-speed ball towards goal)
* renders an annotated MP4 with overlays:
- per-player labels: id, team, speed, distance
- possession HUD per team
- event banners
"""
ensure_models_loaded()
os.makedirs(out_dir, exist_ok=True)
video_info = sv.VideoInfo.from_video_path(video_path)
fps = video_info.fps
dt = 1.0 / max(fps, 1.0)
tracker = sv.ByteTrack()
tracker.reset()
# homography smoothing
Ms = deque(maxlen=5)
# stats
distance_covered_m = defaultdict(float) # tid -> meters
possession_time_player = defaultdict(float) # tid -> seconds
possession_time_team = defaultdict(float) # team_id -> seconds
team_of_player: Dict[int, int] = {} # tid -> team_id
# per-player richer stats for coaches
player_stats: Dict[int, Dict[str, Any]] = defaultdict(
lambda: {
"distance_m": 0.0,
"max_speed_kmh": 0.0,
"time_def_third_s": 0.0,
"time_mid_third_s": 0.0,
"time_att_third_s": 0.0,
"touches": 0,
"successful_passes": 0,
"received_passes": 0,
"shots": 0,
"tackles": 0,
"interceptions": 0,
"clearances": 0,
}
)
events: List[Dict[str, Any]] = []
# last positions for speed / distance (per frame)
prev_positions: Dict[int, np.ndarray] = {}
prev_owner_tid: Optional[int] = None
prev_ball_pos_pitch: Optional[np.ndarray] = None
    # SoccerPitchConfiguration coordinates are in centimeters; convert to
    # meters so the distances, speeds, and thresholds below use real units.
    CM_PER_M = 100.0
    # goal centers in pitch coordinates, meters (x is length, y is width)
    goal_centers = {
        0: np.array([0.0, PITCH_CONFIG.width / 2.0]) / CM_PER_M,
        1: np.array([PITCH_CONFIG.length, PITCH_CONFIG.width / 2.0]) / CM_PER_M,
    }
# annotators
ellipse_annotator = sv.EllipseAnnotator(
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
thickness=2,
)
label_annotator = sv.LabelAnnotator(
color=sv.ColorPalette.from_hex(["#00BFFF", "#FF1493", "#FFD700"]),
text_color=sv.Color.from_hex("#000000"),
text_position=sv.Position.BOTTOM_CENTER,
)
triangle_annotator = sv.TriangleAnnotator(
color=sv.Color.from_hex("#FFD700"), base=25, height=21, outline_thickness=1
)
sink_path = os.path.join(out_dir, "annotated_events.mp4")
sink = sv.VideoSink(sink_path, video_info)
# text overlay control
current_event_text = ""
event_text_frames_left = 0
EVENT_TEXT_DURATION_S = 2.0
EVENT_TEXT_DURATION_FRAMES = int(EVENT_TEXT_DURATION_S * fps)
frame_generator = sv.get_video_frames_generator(video_path)
with sink:
for frame_idx, frame in enumerate(
tqdm(frame_generator, total=video_info.total_frames, desc="analyze + annotate")
):
t = frame_idx * dt
# --- detections + tracking ---
det_result = PLAYER_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
detections = sv.Detections.from_inference(det_result)
ball_dets = detections[detections.class_id == BALL_ID]
ball_dets.xyxy = sv.pad_boxes(xyxy=ball_dets.xyxy, px=10)
non_ball = detections[detections.class_id != BALL_ID]
non_ball = non_ball.with_nms(threshold=0.5, class_agnostic=True)
tracked = tracker.update_with_detections(non_ball)
goalkeepers_dets = tracked[tracked.class_id == GOALKEEPER_ID]
players_dets = tracked[tracked.class_id == PLAYER_ID]
referees_dets = tracked[tracked.class_id == REFEREE_ID]
# --- field homography ---
field_result = FIELD_DETECTION_MODEL.infer(frame, confidence=0.3)[0]
key_points = sv.KeyPoints.from_inference(field_result)
filt = key_points.confidence[0] > 0.5
frame_ref = key_points.xy[0][filt]
pitch_ref = np.array(PITCH_CONFIG.vertices)[filt]
if len(frame_ref) < 4:
# Not enough field points: just draw detections and skip advanced stats
annotated = frame.copy()
annotated = ellipse_annotator.annotate(scene=annotated, detections=players_dets)
annotated = triangle_annotator.annotate(scene=annotated, detections=ball_dets)
sink.write_frame(annotated)
continue
transformer = ViewTransformer(source=frame_ref, target=pitch_ref)
Ms.append(transformer.m)
transformer.m = np.mean(np.array(Ms), axis=0)
# --- team classification & pitch positions ---
frame_players_xy_pitch = None
frame_ball_pos_pitch = None
current_positions: Dict[int, np.ndarray] = {}
current_speed_kmh: Dict[int, float] = {}
if len(players_dets) > 0:
crops = [sv.crop_image(frame, xyxy) for xyxy in players_dets.xyxy]
team_preds = TEAM_CLASSIFIER.predict(crops)
players_dets.class_id = team_preds # now class_id = team_id (0/1)
frame_players_xy_img = players_dets.get_anchors_coordinates(
sv.Position.BOTTOM_CENTER
)
                frame_players_xy_pitch = (
                    transformer.transform_points(points=frame_players_xy_img)
                    / CM_PER_M  # SoccerPitchConfiguration units are cm -> meters
                )
                pitch_length = PITCH_CONFIG.length / CM_PER_M  # meters
for tid, team_id, pos_pitch in zip(
players_dets.tracker_id, players_dets.class_id, frame_players_xy_pitch
):
tid_int = int(tid)
team_of_player[tid_int] = int(team_id)
current_positions[tid_int] = pos_pitch
prev_pos = prev_positions.get(tid_int)
speed_kmh = 0.0
if prev_pos is not None:
dist_m = float(np.linalg.norm(pos_pitch - prev_pos))
distance_covered_m[tid_int] += dist_m
player_stats[tid_int]["distance_m"] += dist_m
speed_kmh = (dist_m / dt) * 3.6
player_stats[tid_int]["max_speed_kmh"] = max(
player_stats[tid_int]["max_speed_kmh"], speed_kmh
)
current_speed_kmh[tid_int] = speed_kmh
# zone times: defensive / middle / attacking thirds
x_pos = pos_pitch[0]
if x_pos < pitch_length / 3.0:
player_stats[tid_int]["time_def_third_s"] += dt
elif x_pos < 2.0 * pitch_length / 3.0:
player_stats[tid_int]["time_mid_third_s"] += dt
else:
player_stats[tid_int]["time_att_third_s"] += dt
            if len(ball_dets) > 0:
                frame_ball_xy_img = ball_dets.get_anchors_coordinates(
                    sv.Position.BOTTOM_CENTER
                )
                frame_ball_xy_pitch = (
                    transformer.transform_points(points=frame_ball_xy_img) / CM_PER_M
                )
                frame_ball_pos_pitch = frame_ball_xy_pitch[0]
# --- possession owner ---
owner_tid: Optional[int] = None
POSSESSION_RADIUS_M = 5.0
if frame_ball_pos_pitch is not None and frame_players_xy_pitch is not None:
dists = np.linalg.norm(frame_players_xy_pitch - frame_ball_pos_pitch, axis=1)
j = int(np.argmin(dists))
if dists[j] < POSSESSION_RADIUS_M:
owner_tid = int(players_dets.tracker_id[j])
# accumulate possession time
if owner_tid is not None:
possession_time_player[owner_tid] += dt
owner_team = team_of_player.get(owner_tid)
if owner_team is not None:
possession_time_team[owner_team] += dt
# --- helper to register events & banner text ---
def register_event(ev: Dict[str, Any], text: str):
nonlocal current_event_text, event_text_frames_left
events.append(ev)
if text:
current_event_text = text
event_text_frames_left = EVENT_TEXT_DURATION_FRAMES
# --- possession change events, passes, tackles, interceptions ---
if owner_tid != prev_owner_tid:
if owner_tid is not None:
player_stats[owner_tid]["touches"] += 1
if owner_tid is not None and prev_owner_tid is not None:
prev_team = team_of_player.get(prev_owner_tid)
cur_team = team_of_player.get(owner_tid)
travel_m = 0.0
if prev_ball_pos_pitch is not None and frame_ball_pos_pitch is not None:
travel_m = float(
np.linalg.norm(frame_ball_pos_pitch - prev_ball_pos_pitch)
)
MIN_PASS_TRAVEL_M = 3.0
if prev_team is not None and cur_team is not None:
if prev_team == cur_team and travel_m > MIN_PASS_TRAVEL_M:
# pass
register_event(
{
"type": "pass",
"t": float(t),
"from_tid": int(prev_owner_tid),
"to_tid": int(owner_tid),
"team_id": int(cur_team),
"extra": {"distance_m": travel_m},
},
f"Pass: #{prev_owner_tid} → #{owner_tid} (Team {cur_team})",
)
player_stats[prev_owner_tid]["successful_passes"] += 1
player_stats[owner_tid]["received_passes"] += 1
elif prev_team != cur_team:
                        # tackle vs interception: possession won within ~3 m of
                        # the previous owner counts as a direct duel (tackle);
                        # 999.0 is a JSON-safe sentinel for "unknown distance"
                        d_pp = 999.0
pos_prev = prev_positions.get(int(prev_owner_tid))
pos_cur = current_positions.get(int(owner_tid))
if pos_prev is not None and pos_cur is not None:
d_pp = float(np.linalg.norm(pos_prev - pos_cur))
ev_type = "tackle" if d_pp < 3.0 else "interception"
label = "Tackle" if ev_type == "tackle" else "Interception"
register_event(
{
"type": ev_type,
"t": float(t),
"from_tid": int(prev_owner_tid),
"to_tid": int(owner_tid),
"team_id": int(cur_team),
"extra": {
"player_distance_m": d_pp,
"ball_travel_m": travel_m,
},
},
f"{label}: #{owner_tid} wins ball from #{prev_owner_tid}",
)
if ev_type == "tackle":
player_stats[owner_tid]["tackles"] += 1
else:
player_stats[owner_tid]["interceptions"] += 1
# generic possession-change event
register_event(
{
"type": "possession_change",
"t": float(t),
"from_tid": int(prev_owner_tid) if prev_owner_tid is not None else None,
"to_tid": int(owner_tid) if owner_tid is not None else None,
"team_id": int(team_of_player.get(owner_tid))
if owner_tid is not None
else None,
"extra": {},
},
"" if owner_tid is None else f"Team {team_of_player.get(owner_tid)} in possession",
)
# --- shot / clearance based on ball speed & direction ---
if (
prev_ball_pos_pitch is not None
and frame_ball_pos_pitch is not None
and owner_tid is not None
):
v = (frame_ball_pos_pitch - prev_ball_pos_pitch) / dt # m/s
speed_mps = float(np.linalg.norm(v))
speed_kmh = speed_mps * 3.6
HIGH_SPEED_KMH = 18.0 # threshold for "hard" actions
if speed_kmh > HIGH_SPEED_KMH:
shooter_team = team_of_player.get(owner_tid)
if shooter_team is not None:
target_goal = goal_centers[1 - shooter_team]
direction = target_goal - frame_ball_pos_pitch
cos_angle = float(
np.dot(v, direction)
/ (np.linalg.norm(v) * np.linalg.norm(direction) + 1e-6)
)
                        if cos_angle > 0.8:  # heading within ~37 degrees of the goal
register_event(
{
"type": "shot",
"t": float(t),
"from_tid": int(owner_tid),
"to_tid": None,
"team_id": int(shooter_team),
"extra": {"speed_kmh": speed_kmh},
},
f"Shot by #{owner_tid} (Team {shooter_team}) – {speed_kmh:.1f} km/h",
)
player_stats[owner_tid]["shots"] += 1
else:
register_event(
{
"type": "clearance",
"t": float(t),
"from_tid": int(owner_tid),
"to_tid": None,
"team_id": int(shooter_team),
"extra": {"speed_kmh": speed_kmh},
},
f"Clearance by #{owner_tid} (Team {shooter_team})",
)
player_stats[owner_tid]["clearances"] += 1
prev_owner_tid = owner_tid
prev_ball_pos_pitch = frame_ball_pos_pitch
prev_positions = current_positions
# --- frame drawing ---
annotated = frame.copy()
# build labels for players: id + team + current speed + total distance
player_labels: List[str] = []
if frame_players_xy_pitch is not None and len(players_dets) > 0:
for tid, pos_pitch in zip(players_dets.tracker_id, frame_players_xy_pitch):
tid_int = int(tid)
team_id = team_of_player.get(tid_int, -1)
speed_kmh = current_speed_kmh.get(tid_int, 0.0)
d_total = distance_covered_m[tid_int]
player_labels.append(
f"#{tid_int} T{team_id} {speed_kmh:4.1f} km/h {d_total:.1f} m"
)
annotated = ellipse_annotator.annotate(
scene=annotated, detections=players_dets
)
annotated = label_annotator.annotate(
scene=annotated, detections=players_dets, labels=player_labels
)
# draw ball
annotated = triangle_annotator.annotate(scene=annotated, detections=ball_dets)
# --- HUD: possession percentages ---
            # the 1e-6 guard keeps the division safe before any possession is logged
            total_poss_time = sum(possession_time_team.values()) + 1e-6
            team0_pct = 100.0 * possession_time_team.get(0, 0.0) / total_poss_time
            team1_pct = 100.0 * possession_time_team.get(1, 0.0) / total_poss_time
hud_text = (
f"Team 0 Ball Control: {team0_pct:5.2f}% "
f"Team 1 Ball Control: {team1_pct:5.2f}%"
)
cv2.rectangle(
annotated,
(20, annotated.shape[0] - 60),
(annotated.shape[1] - 20, annotated.shape[0] - 20),
(255, 255, 255),
-1,
)
cv2.putText(
annotated,
hud_text,
(30, annotated.shape[0] - 30),
cv2.FONT_HERSHEY_SIMPLEX,
0.8,
(0, 0, 0),
2,
cv2.LINE_AA,
)
# --- event banner ---
if event_text_frames_left > 0 and current_event_text:
cv2.rectangle(
annotated, (20, 20), (annotated.shape[1] - 20, 90), (255, 255, 255), -1
)
cv2.putText(
annotated,
current_event_text,
(30, 70),
cv2.FONT_HERSHEY_SIMPLEX,
1.0,
(0, 0, 0),
2,
cv2.LINE_AA,
)
event_text_frames_left -= 1
sink.write_frame(annotated)
# finalize stats
total_poss = sum(possession_time_team.values()) + 1e-6
possession_percent_team = {
int(team): 100.0 * t_sec / total_poss for team, t_sec in possession_time_team.items()
}
stats = {
"distance_covered_m": {str(tid): float(d) for tid, d in distance_covered_m.items()},
"possession_time_player_s": {
str(tid): float(t_sec) for tid, t_sec in possession_time_player.items()
},
"possession_time_team_s": {
int(team): float(t_sec) for team, t_sec in possession_time_team.items()
},
"possession_percent_team": possession_percent_team,
"team_of_player": {str(tid): int(team) for tid, team in team_of_player.items()},
"player_stats": {
str(tid): {
k: float(v) if isinstance(v, (int, float)) else v
for k, v in stats_dict.items()
}
for tid, stats_dict in player_stats.items()
},
}
return {
"annotated_video": sink_path,
"stats": stats,
"events": events,
}
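# Example post-processing (hypothetical call): tally detected events by type
# from the dict returned above.
#
#     from collections import Counter
#     out = step_analyze_and_annotate_video("match.mp4", "out/analysis")
#     print(Counter(ev["type"] for ev in out["events"]))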
# -------------------- 9. full pipeline entrypoint --------------------
def run_full_pipeline(video_path: str, job_dir: str) -> Dict[str, Any]:
"""
Run the full notebook-equivalent pipeline on a video and save all artifacts
into job_dir. Returns paths + stats for the FastAPI app.
"""
set_job_dir(job_dir)
update_progress("initializing", 0.0, "Loading models...")
ensure_models_loaded()
os.makedirs(job_dir, exist_ok=True)
update_progress("siglip", 0.10, "Running SigLIP clustering...")
siglip_out = step_siglip_clustering(video_path, os.path.join(job_dir, "siglip"))
update_progress("team_classifier", 0.25, "Training TeamClassifier...")
train_team_classifier_on_video(video_path)
update_progress("basic_frames", 0.35, "Generating basic annotated frames...")
basic_paths = step_basic_frames(video_path, os.path.join(job_dir, "frames"))
update_progress("advanced_views", 0.45, "Generating advanced radar / Voronoi views...")
adv_paths = step_single_frame_advanced(video_path, os.path.join(job_dir, "advanced"))
update_progress("ball_path", 0.60, "Computing ball path and heatmap...")
ball_paths = step_ball_path(video_path, os.path.join(job_dir, "ball_path"))
update_progress(
"events_video",
0.80,
"Analyzing match, computing speed/distance, and rendering event-annotated video...",
)
analysis_out = step_analyze_and_annotate_video(
video_path, os.path.join(job_dir, "analysis")
)
result = {
"basic": basic_paths,
"advanced": adv_paths,
"ball": ball_paths,
"stats": analysis_out["stats"],
"events": analysis_out["events"],
"annotated_video": analysis_out["annotated_video"],
"siglip_html": siglip_out["plot_html"],
}
# Save a copy for the UI result page
result_path = os.path.join(job_dir, "result.json")
with open(result_path, "w", encoding="utf-8") as f:
json.dump(result, f)
update_progress("done", 1.0, "Completed")
return result
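# ------------------------------------------------------------------
# Convenience CLI – a minimal sketch for running the module directly;
# the argument names are illustrative, not a stable interface.
# ------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the full soccer analysis pipeline.")
    parser.add_argument("video_path", help="Path to the input match video")
    parser.add_argument("job_dir", help="Directory to write artifacts into")
    args = parser.parse_args()
    out = run_full_pipeline(args.video_path, args.job_dir)
    print(json.dumps({"annotated_video": out["annotated_video"]}, indent=2))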