diff --git a/README.md b/README.md
index 43b8c780..f57573d4 100644
--- a/README.md
+++ b/README.md
@@ -17,9 +17,11 @@ The **unravelsports** package aims to aid researchers, analysts and enthusiasts
This package currently supports:
- β½ π [**Polars DataFrame Conversion**](#polars-dataframes)
- β½ π [**Graph Neural Network**](#graph-neural-networks) Training, Graph Conversion and Prediction
- [[Bekkers & Sahasrabudhe (2023)](https://arxiv.org/pdf/2411.17450)]
+ [[π Bekkers & Sahasrabudhe (2023)](https://arxiv.org/pdf/2411.17450)]
- β½ [**Pressing Intensity**](#pressing-intensity)
- [[Bekkers (2024)](https://arxiv.org/pdf/2501.04712)]
+ [[π Bekkers (2024)](https://arxiv.org/pdf/2501.04712)]
+- ⚽ [**Formation and Position Identification (EFPI)**](#formation-and-position-identification)
+ [[π Bekkers (2025)](https://arxiv.org/pdf/2506.23843)]
π Features
-----
@@ -28,7 +30,7 @@ This package currently supports:
β½π **Convert Tracking Data** into [Polars DataFrames](https://pola.rs/) for rapid data conversion and data processing.
-β½ For soccer we rely on [Kloppy](https://kloppy.pysport.org/) and as such we support _Sportec_$^1$, _SkillCorner_$^1$, _PFF_$^{1, 2}$, _Metrica_$^1$, _StatsPerform_, _Tracab (CyronHego)_ and _SecondSpectrum_ tracking data.
+β½ For soccer we rely on [Kloppy](https://kloppy.pysport.org/) and as such we support Sportec, SkillCorner, PFF / GradientSports, Metrica, StatsPerform, Tracab (CyronHego), SecondSpectrum, HawkEye and Signality tracking data.
```python
from unravel.soccer import KloppyPolarsDataset
@@ -48,9 +50,6 @@ kloppy_polars_dataset = KloppyPolarsDataset(
| 4 | 1 | 0 days 00:00:00 | 10000 | alive | DFL-OBJ-0001HW | -46.26 | 0.08 | 0 | DFL-CLU-000005 | GK | DFL-MAT-J03WPY | 0.357 | 0.071 | 0 | 0.364 | 0 | 0 | 0 | 0 | DFL-CLU-00000P | False |
-$^1$ Open data available through kloppy.
-
-$^2$ Currently unreleased in kloppy, only available through kloppy master branch. [Click here for World Cup 2022 Dataset](https://www.blog.fc.pff.com/blog/enhanced-2022-world-cup-dataset)
π For American Football we use [BigDataBowl Data](https://www.kaggle.com/competitions/nfl-big-data-bowl-2025/data) directly.
@@ -86,6 +85,8 @@ converter = SoccerGraphConverter(
)
```
+---
+
### **Pressing Intensity**
Compute [**Pressing Intensity**](https://arxiv.org/abs/2501.04712) for a whole game (or segment) of Soccer tracking data.
@@ -113,6 +114,29 @@ model.fit(

+---
+
+### **Formation and Position Identification**
+
+Compute [Elastic Formation and Position Identification, **EFPI**](https://arxiv.org/pdf/2506.23843) for individual frames, possessions, periods or specific time intervals for Soccer.
+
+For all interval strings accepted by the `every` parameter, see the [Polars documentation](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html).
+
+```python
+from unravel.soccer import EFPI
+
+model = EFPI(dataset=kloppy_polars_dataset)
+model.fit(
+ # Defaults to all 65 formations (None), or specify a subset (e.g. ["442", "433"])
+ formations=None,
+ # A time interval (e.g. "1m", "1m14s", "2m30s"), or one of "possession", "period" or "frame".
+ every="5m",
+ substitutions="drop",
+ change_threshold=0.1,
+ change_after_possession=True,
+)
+```
+
β ***More to come soon...!***
π Quick Start
diff --git a/tests/test_soccer.py b/tests/test_soccer.py
index 491882ab..9bc64257 100644
--- a/tests/test_soccer.py
+++ b/tests/test_soccer.py
@@ -3,6 +3,7 @@
SoccerGraphConverter,
KloppyPolarsDataset,
PressingIntensity,
+ EFPI,
Constant,
Column,
Group,
@@ -1111,3 +1112,272 @@ def test_plot_error_wrong_extension_for_mp4(
end_timestamp=pl.duration(seconds=11, milliseconds=900),
period_id=1,
)
+
+    def test_efpi_frame_drop_0_true(
+        self, kloppy_polars_sportec_dataset: KloppyPolarsDataset
+    ):
+        """Fit EFPI per frame with a zero change threshold and check one frame."""
+        model = EFPI(
+            dataset=kloppy_polars_sportec_dataset,
+        )
+
+        model = model.fit(
+            formations=None,
+            every="frame",
+            substitutions="drop",
+            change_threshold=0.0,
+            change_after_possession=True,
+        )
+
+        single_frame = model.output.filter(pl.col(Column.FRAME_ID) == 10018)
+
+        # Frame-level fits do not build a segment table.
+        assert model.segments is None
+        assert model.output.columns == [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.FRAME_ID,
+            Column.OBJECT_ID,
+            Column.TEAM_ID,
+            "position",
+            "formation",
+            Column.BALL_OWNING_TEAM_ID,
+            "is_attacking",
+        ]
+        assert len(model.output) == 483
+
+        # (object id, output column, expected value) for frame 10018.
+        expected = [
+            ("DFL-OBJ-00008F", "position", "CB"),
+            ("DFL-OBJ-00008F", "formation", "3232"),
+            ("DFL-OBJ-00008F", "is_attacking", False),
+            ("DFL-OBJ-002FXT", "position", "LW"),
+            ("DFL-OBJ-002FXT", "formation", "31222"),
+            ("DFL-OBJ-002FXT", "is_attacking", True),
+            ("DFL-OBJ-0001HW", "position", "GK"),
+            ("DFL-OBJ-0028FW", "position", "GK"),
+        ]
+        for object_id, column, value in expected:
+            player_rows = single_frame.filter(pl.col(Column.OBJECT_ID) == object_id)
+            assert player_rows[column][0] == value
+
+    def test_efpi_possession_drop_sg(
+        self, kloppy_polars_sportec_dataset: KloppyPolarsDataset
+    ):
+        """Fit EFPI per possession on the Shaw-Glickman subset and check possession 1."""
+        fit_kwargs = dict(
+            formations="shaw-glickman",
+            every="possession",
+            substitutions="drop",
+            change_threshold=0.1,
+            change_after_possession=True,
+        )
+        model = EFPI(dataset=kloppy_polars_sportec_dataset)
+        model = model.fit(**fit_kwargs)
+
+        segments = model.segments
+        assert isinstance(segments, pl.DataFrame)
+        assert len(segments) == 1
+        assert segments.columns == [
+            "possession_id",
+            "n_frames",
+            "start_timestamp",
+            "end_timestamp",
+            "start_frame_id",
+            "end_frame_id",
+        ]
+
+        output = model.output
+        assert output.columns == [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.BALL_OWNING_TEAM_ID,
+            "possession_id",
+            Column.OBJECT_ID,
+            Column.TEAM_ID,
+            "position",
+            "formation",
+            "is_attacking",
+        ]
+        assert len(output) == 23
+
+        single_possession = output.filter(pl.col("possession_id") == 1)
+
+        # (object id, output column, expected value) for possession 1.
+        expected = [
+            ("DFL-OBJ-00008F", "position", "CB"),
+            ("DFL-OBJ-00008F", "formation", "3232"),
+            ("DFL-OBJ-00008F", "is_attacking", False),
+            ("DFL-OBJ-002FXT", "position", "LW"),
+            ("DFL-OBJ-002FXT", "formation", "3241"),
+            ("DFL-OBJ-002FXT", "is_attacking", True),
+            ("DFL-OBJ-0001HW", "position", "GK"),
+            ("DFL-OBJ-0028FW", "position", "GK"),
+        ]
+        for object_id, column, value in expected:
+            player_rows = single_possession.filter(
+                pl.col(Column.OBJECT_ID) == object_id
+            )
+            assert player_rows[column][0] == value
+
+    def test_efpi_period_442(self, kloppy_polars_sportec_dataset: KloppyPolarsDataset):
+        """Fit EFPI per period with only the 442 template and check period 1."""
+        fit_kwargs = dict(
+            formations=["442"],
+            every="period",
+            substitutions="drop",
+            change_threshold=0.1,
+            change_after_possession=True,
+        )
+        model = EFPI(dataset=kloppy_polars_sportec_dataset)
+        model = model.fit(**fit_kwargs)
+
+        segments = model.segments
+        assert isinstance(segments, pl.DataFrame)
+        assert len(segments) == 1
+        assert segments.columns == [
+            "period_id",
+            "n_frames",
+            "start_timestamp",
+            "end_timestamp",
+            "start_frame_id",
+            "end_frame_id",
+        ]
+
+        output = model.output
+        assert output.columns == [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.BALL_OWNING_TEAM_ID,
+            Column.OBJECT_ID,
+            Column.TEAM_ID,
+            "position",
+            "formation",
+            "is_attacking",
+        ]
+        assert len(output) == 23
+
+        single_period = output.filter(pl.col("period_id") == 1)
+
+        # (object id, output column, expected value) for period 1.
+        expected = [
+            ("DFL-OBJ-00008F", "position", "RCB"),
+            ("DFL-OBJ-00008F", "formation", "442"),
+            ("DFL-OBJ-00008F", "is_attacking", False),
+            ("DFL-OBJ-002FXT", "position", "LM"),
+            ("DFL-OBJ-002FXT", "formation", "442"),
+            ("DFL-OBJ-002FXT", "is_attacking", True),
+            ("DFL-OBJ-0001HW", "position", "GK"),
+            ("DFL-OBJ-0028FW", "position", "GK"),
+        ]
+        for object_id, column, value in expected:
+            player_rows = single_period.filter(pl.col(Column.OBJECT_ID) == object_id)
+            assert player_rows[column][0] == value
+
+    def test_efpi_wrong(self, kloppy_polars_sportec_dataset):
+        """An interval string Polars cannot parse ("5mm") must surface as a PanicException."""
+        import pytest
+        from polars.exceptions import PanicException
+
+        invalid_fit_kwargs = dict(
+            formations=["442"],
+            every="5mm",
+            substitutions="drop",
+            change_threshold=0.1,
+            change_after_possession=True,
+        )
+
+        with pytest.raises(PanicException):
+            model = EFPI(dataset=kloppy_polars_sportec_dataset)
+            model.fit(**invalid_fit_kwargs)
diff --git a/unravel/soccer/dataset/kloppy_polars.py b/unravel/soccer/dataset/kloppy_polars.py
index 92e9aa98..7e7d5e6c 100644
--- a/unravel/soccer/dataset/kloppy_polars.py
+++ b/unravel/soccer/dataset/kloppy_polars.py
@@ -655,6 +655,7 @@ def __apply_settings(
max_player_acceleration=self._max_player_acceleration,
max_ball_acceleration=self._max_ball_acceleration,
ball_carrier_threshold=self._ball_carrier_threshold,
+ frame_rate=self.kloppy_dataset.metadata.frame_rate,
)
def load(
diff --git a/unravel/soccer/models/__init__.py b/unravel/soccer/models/__init__.py
index 83ebcd77..ddd4ae6f 100644
--- a/unravel/soccer/models/__init__.py
+++ b/unravel/soccer/models/__init__.py
@@ -1,2 +1,3 @@
from .pressing_intensity import *
+from .formations.efpi import EFPI
from .utils import *
diff --git a/unravel/soccer/models/formations/__init__.py b/unravel/soccer/models/formations/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/unravel/soccer/models/formations/detection.py b/unravel/soccer/models/formations/detection.py
new file mode 100644
index 00000000..5219e394
--- /dev/null
+++ b/unravel/soccer/models/formations/detection.py
@@ -0,0 +1,235 @@
+import numpy as np
+
+from kloppy.domain import Orientation
+
+import numpy as np
+import polars as pl
+
+from dataclasses import dataclass, field
+
+from typing import List, Dict
+
+from ...dataset.kloppy_polars import (
+ KloppyPolarsDataset,
+ Column,
+ Constant,
+)
+
+
+@dataclass
+class DetectedFormation:
+    """Most recently accepted formation of one team and its per-player labels.
+
+    ``labels`` and ``ids`` are parallel arrays: ``labels[i]`` is the position
+    label given to the object whose id is ``ids[i]``. ``labels_dict`` maps
+    object id -> label and accumulates across updates, so ids seen in earlier
+    updates stay resolvable.
+    """
+
+    is_attack: bool
+    formation_name: str = None
+    cost: float = None
+    # ``np.ndarray`` cannot be called without a shape, so it cannot serve as a
+    # zero-argument default factory; start from empty arrays instead.
+    labels: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=str))
+    ids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=str))
+
+    def __post_init__(self):
+        self.labels = self._as_array(self.labels)
+        self.ids = self._as_array(self.ids)
+        self.n_outfield_players = len(self.labels[self.labels != "GK"])
+        self.labels_dict = dict(zip(self.ids, self.labels))
+
+    @staticmethod
+    def _as_array(values) -> np.ndarray:
+        """Return ``values`` as an array, treating ``None`` as empty."""
+        if values is None:
+            return np.empty(0, dtype=str)
+        return np.asarray(values)
+
+    def update(
+        self,
+        is_attack: bool,
+        formation_name: str,
+        cost: float,
+        labels: np.ndarray = None,
+        ids: np.ndarray = None,
+    ):
+        """Overwrite the stored formation and merge in the new player labels.
+
+        When ``labels`` or ``ids`` is omitted only the scalar fields change.
+        Otherwise ``ids``/``labels`` are replaced as well, so that comparisons
+        against the current player set see the latest ids, and the new labels
+        are merged into ``labels_dict``.
+        """
+        self.is_attack = is_attack
+        self.formation_name = formation_name
+        self.cost = cost
+
+        if labels is None or ids is None:
+            return
+
+        self.labels = self._as_array(labels)
+        self.ids = self._as_array(ids)
+        self.n_outfield_players = len(self.labels[self.labels != "GK"])
+
+        for object_id, label in zip(self.ids, self.labels):
+            self.labels_dict[object_id] = label
+
+
+@dataclass
+class FormationDetection:
+    """Base class for formation detection models on a ``KloppyPolarsDataset``.
+
+    After construction ``self.settings`` holds the dataset settings and
+    ``self.dataset`` is rebound to the dataset's underlying ``.data`` frame.
+    Subclasses implement ``fit``.
+
+    Raises:
+        ValueError: if ``dataset`` is not a ``KloppyPolarsDataset`` or its
+            orientation is not ``Orientation.BALL_OWNING_TEAM``.
+    """
+
+    dataset: KloppyPolarsDataset
+    chunk_size: int = field(init=True, repr=False, default=2_000)
+
+    def __post_init__(self):
+        if not isinstance(self.dataset, KloppyPolarsDataset):
+            raise ValueError("dataset should be of type KloppyPolarsDataset...")
+
+        if self.dataset.settings.orientation != Orientation.BALL_OWNING_TEAM:
+            raise ValueError(
+                "KloppyPolarsDataset orientation should be Orientation.BALL_OWNING_TEAM..."
+            )
+
+        self.settings = self.dataset.settings
+        # From here on ``dataset`` refers to the data frame, not the wrapper.
+        self.dataset = self.dataset.data
+
+    def __repr__(self):
+        n_frames = (
+            self.output[Column.FRAME_ID].n_unique() if hasattr(self, "output") else None
+        )
+        # ``_window_size`` is not assigned anywhere in this class, so read it
+        # defensively and fall back to 1 when it is missing or None.
+        window_size = getattr(self, "_window_size", None)
+        if window_size is None:
+            window_size = 1
+        return f"FormationDetection(n_frames={n_frames}, window_size={window_size})"
+
+    @property
+    def _exprs_variables(self):
+        """Columns handed to the per-group compute function, in this order."""
+        return [
+            Column.X,
+            Column.Y,
+            Column.TEAM_ID,
+            Column.BALL_OWNING_TEAM_ID,
+            Column.OBJECT_ID,
+            Column.POSITION_NAME,
+        ]
+
+    def __compute(self, args: List[pl.Series]) -> dict:
+        raise NotImplementedError()
+
+    def fit(
+        self,
+    ):
+        """Run the detection; must be provided by subclasses."""
+        raise NotImplementedError()
+
+
+@dataclass
+class Formations:
+    """Candidate formation templates from mplsoccer plus the detected result per team.
+
+    ``formations`` selects the candidates:
+
+    - ``None``: every mplsoccer formation whose name is not purely alphabetic,
+    - ``"shaw-glickman"``: the fixed subset listed in ``get_options``,
+    - a list of formation names, or a single formation name.
+
+    Template coordinates and labels are stored per number of outfield players
+    (goalkeeper excluded), once for each attacking direction.
+    """
+
+    pitch_length: float
+    pitch_width: float
+    formations: List[str] = None
+    detected_formations: Dict[str, DetectedFormation] = field(init=False, repr=False)
+
+    def __post_init__(self):
+        self.detected_formations = dict()
+        self._pitch()
+        self.get_formations()
+
+    def set_detected_formation(
+        self,
+        team_id: str,
+        is_attack: bool,
+        name: str,
+        cost: float,
+        labels: np.ndarray = None,
+        ids: np.ndarray = None,
+    ):
+        """Create or update the stored ``DetectedFormation`` of ``team_id``."""
+        existing = self.detected_formations.get(team_id)
+        if existing is None:
+            self.detected_formations[team_id] = DetectedFormation(
+                is_attack=is_attack,
+                formation_name=name,
+                cost=cost,
+                labels=labels,
+                ids=ids,
+            )
+        else:
+            existing.update(
+                is_attack=is_attack,
+                formation_name=name,
+                cost=cost,
+                labels=labels,
+                ids=ids,
+            )
+
+    def get_detected_formations_as_dict(self, object_ids: list, team_ids: list):
+        """Return parallel lists of position and formation for the given objects.
+
+        The ball gets ``Constant.BALL`` for both position and formation; every
+        other object is looked up in its team's stored formation.
+        """
+        positions, formations = [], []
+
+        for object_id, team_id in zip(object_ids, team_ids):
+            if object_id == Constant.BALL:
+                positions.append(Constant.BALL)
+                formations.append(Constant.BALL)
+                continue
+
+            team_formation = self.detected_formations[team_id]
+            positions.append(team_formation.labels_dict[object_id])
+            formations.append(team_formation.formation_name)
+
+        return {
+            Column.OBJECT_ID: object_ids,
+            Column.TEAM_ID: team_ids,
+            "position": positions,
+            "formation": formations,
+        }
+
+    def get_options(self):
+        """Return the list of candidate formation names."""
+        if self.formations is None:
+            return [x for x in self.pitch.formations if not x.isalpha()]
+        if isinstance(self.formations, str):
+            if self.formations == "shaw-glickman":
+                return [
+                    "5221",
+                    "352",
+                    "343flat",
+                    "3232",
+                    "4222",
+                    "41212",
+                    "343",
+                    "41221",
+                    "433",
+                    "4321",
+                    "4141",
+                    "442",
+                    "3331",
+                    "31312",
+                    "3241",
+                    "3142",
+                    "2422",
+                    "2332",
+                    "2431",
+                ]
+            # A single name such as "442" would otherwise be iterated
+            # character by character by the caller.
+            return [self.formations]
+        return self.formations
+
+    def _pitch(self):
+        """Build the mplsoccer pitch the templates are read from."""
+        try:
+            from mplsoccer import Pitch
+        except ImportError as err:
+            raise ImportError(
+                "Seems like you don't have mplsoccer installed. Please"
+                " install it using: pip install mplsoccer"
+            ) from err
+        self.pitch = Pitch(
+            pitch_type="secondspectrum",
+            pitch_length=self.pitch_length,
+            pitch_width=self.pitch_width,
+        )
+
+    def get_positions(self, formation: str):
+        """Return the mplsoccer positions of ``formation``.
+
+        Raises:
+            ValueError: if mplsoccer does not know ``formation``.
+        """
+        if formation not in self.pitch.formations:
+            raise ValueError(f"Formation {formation} is not available.")
+        return self.pitch.get_formation(formation)
+
+    def get_formation_positions_left_to_right(self):
+        return self._formations_coords_ltr
+
+    def get_formation_positions_right_to_left(self):
+        return self._formations_coords_rtl
+
+    def get_formation_labels_left_to_right(self):
+        return self._formations_labels_ltr
+
+    def get_formation_labels_right_to_left(self):
+        return self._formations_labels_rtl
+
+    def get_formations(self):
+        """Load coordinates and labels of every candidate formation.
+
+        Each store is a dict ``{n_outfield_players: {formation_name: array}}``.
+        The 8, 9 and 10 player buckets always exist; other counts get a bucket
+        on demand instead of raising ``KeyError``.
+        """
+        self._formations_coords_ltr = {k: dict() for k in [8, 9, 10]}
+        self._formations_coords_rtl = {k: dict() for k in [8, 9, 10]}
+        self._formations_labels_ltr = {k: dict() for k in [8, 9, 10]}
+        self._formations_labels_rtl = {k: dict() for k in [8, 9, 10]}
+
+        for formation in self.get_options():
+            # The goalkeeper is labelled separately and is not part of the
+            # template that outfield players are matched against.
+            outfield = [
+                pos for pos in self.get_positions(formation) if pos.name != "GK"
+            ]
+            n_outfield = len(outfield)
+            labels = np.asarray([pos.name for pos in outfield])
+
+            self._formations_coords_ltr.setdefault(n_outfield, dict())[
+                formation
+            ] = np.array([(pos.x, pos.y) for pos in outfield])
+            self._formations_coords_rtl.setdefault(n_outfield, dict())[
+                formation
+            ] = np.array([(pos.x_flip, pos.y_flip) for pos in outfield])
+            # Labels are identical for both directions; only coordinates flip.
+            self._formations_labels_ltr.setdefault(n_outfield, dict())[
+                formation
+            ] = labels
+            self._formations_labels_rtl.setdefault(n_outfield, dict())[
+                formation
+            ] = labels
diff --git a/unravel/soccer/models/formations/efpi.py b/unravel/soccer/models/formations/efpi.py
new file mode 100644
index 00000000..eb02fed4
--- /dev/null
+++ b/unravel/soccer/models/formations/efpi.py
@@ -0,0 +1,512 @@
+from dataclasses import dataclass
+
+import numpy as np
+
+from typing import Literal, List, Union, Literal, Optional
+
+from kloppy.domain import AttackingDirection, Orientation
+
+from .detection import FormationDetection, Formations
+
+import polars as pl
+
+from ...dataset.kloppy_polars import (
+ Group,
+ Column,
+ Constant,
+)
+
+
+@dataclass
+class EFPI(FormationDetection):
+    """Elastic Formation and Position Identification.
+
+    For every frame or segment the outfield players of each team are matched
+    against each candidate formation template (rescaled to the team's bounding
+    box) with a linear sum assignment; the template with the lowest total
+    distance is selected. After ``fit`` the result is in ``self.output`` and,
+    for segment-based fits, a segment table is in ``self.segments``.
+    """
+
+    _fit = False
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.__get_linear_sum_assignment()
+
+    def __get_linear_sum_assignment(self):
+        """Import scipy lazily and keep the solver on the instance."""
+        try:
+            from scipy.optimize import linear_sum_assignment
+        except ImportError:
+            raise ImportError(
+                "Seems like you don't have scipy installed. Please"
+                " install it using: pip install scipy"
+            )
+        self.linear_sum_assignment = linear_sum_assignment
+
+    def __repr__(self):
+        # ``_formations`` only exists after ``fit``; read it defensively.
+        formations = getattr(self, "_formations", None)
+        if formations is None:
+            formations = "mplsoccer"
+
+        if not self._fit:
+            return f"EFPI(n_frames={len(self.dataset)}, formations={formations})"
+        return (
+            f"EFPI(n_frames={len(self.dataset)}, formations={formations}, "
+            f"every={self._every}, substitutions={self._substitutions}, "
+            f"change_after_possession={self._change_after_possession}, "
+            f"change_threshold={self._change_threshold})"
+        )
+
+    @staticmethod
+    def __scale_all_to_bounds(points, min_x, min_y, max_x, max_y):
+        """Rescale all template points jointly into the given bounding box.
+
+        The minimum and maximum are taken over the first two axes of
+        ``points``, so every template is scaled by the same factor. An axis
+        with zero extent keeps scale 1.
+        """
+        global_min = points.min(axis=(0, 1))
+        global_max = points.max(axis=(0, 1))
+
+        extent = global_max - global_min
+        target = np.asarray((max_x - min_x, max_y - min_y), dtype=float)
+
+        # Divide only where the extent is non-zero, so no divide-by-zero
+        # warning is emitted for a degenerate axis.
+        scale = np.ones_like(target)
+        np.divide(target, extent, out=scale, where=extent != 0)
+
+        return (points - global_min) * scale + np.array([min_x, min_y])
+
+    def __assign_formation(
+        self, coordinates: np.ndarray, direction: AttackingDirection
+    ):
+        """Pick the cheapest template for one team and label its players.
+
+        Args:
+            coordinates: one (x, y) row per outfield player.
+            direction: playing direction, selects the template orientation.
+
+        Returns:
+            Tuple of (position labels, unscaled coordinates of the selected
+            template, input coordinates, formation name, assignment cost).
+
+        Raises:
+            ValueError: if ``direction`` is neither LTR nor RTL.
+        """
+        if direction == AttackingDirection.LTR:
+            relevant_formations = self._forms.get_formation_positions_left_to_right()
+            relevant_position_labels = self._forms.get_formation_labels_left_to_right()
+        elif direction == AttackingDirection.RTL:
+            relevant_formations = self._forms.get_formation_positions_right_to_left()
+            relevant_position_labels = self._forms.get_formation_labels_right_to_left()
+        else:
+            raise ValueError("AttackingDirection is not set...")
+
+        numb_players = len(coordinates)
+        # Only templates with the same number of outfield players compete.
+        candidates = relevant_formations[numb_players]
+        forms = list(candidates)
+
+        min_x, max_x = np.min(coordinates[:, 0]), np.max(coordinates[:, 0])
+        min_y, max_y = np.min(coordinates[:, 1]), np.max(coordinates[:, 1])
+
+        _form = self.__scale_all_to_bounds(
+            points=np.asarray(list(candidates.values())),
+            min_x=min_x,
+            min_y=min_y,
+            max_x=max_x,
+            max_y=max_y,
+        )
+
+        # cost_matrices[p, f, s]: distance from player p to slot s of template f.
+        cost_matrices = np.linalg.norm(
+            coordinates[:, np.newaxis, np.newaxis, :] - _form[np.newaxis, :, :, :],
+            axis=-1,
+        )
+
+        # Solve each assignment once and reuse the winner's solution below.
+        assignments = [
+            self.linear_sum_assignment(cost_matrices[:, i, :])
+            for i in range(len(_form))
+        ]
+        costs = np.array(
+            [
+                cost_matrices[:, i, :][rows, cols].sum()
+                for i, (rows, cols) in enumerate(assignments)
+            ]
+        )
+
+        idx = np.argmin(costs)
+        selected_formation_cost = costs[idx]
+        row_ind, col_ind = assignments[idx]
+
+        selected_formation = forms[idx]
+        selected_coords = candidates[selected_formation]
+        players = relevant_position_labels[numb_players][selected_formation][row_ind][
+            col_ind
+        ]
+        return (
+            players,
+            selected_coords,
+            coordinates,
+            selected_formation,
+            selected_formation_cost,
+        )
+
+    def __is_update(self, team_id, formation_cost, object_ids, is_attack):
+        """Decide whether the stored formation of ``team_id`` is replaced."""
+        current = self._forms.detected_formations.get(team_id)
+        if current is None:
+            return True
+        if self._change_threshold is None:
+            return True
+        if set(current.ids) != set(object_ids):
+            # update if we encounter a different set of player ids
+            return True
+        if self._change_after_possession and current.is_attack != is_attack:
+            # update if we switch between attack and defense
+            return True
+        # update if the relative cost improvement passed the threshold
+        relative_gain = (current.cost - formation_cost) / formation_cost
+        return bool(relative_gain > self._change_threshold)
+
+    def __detect(self, is_attack, direction, d):
+        """Detect the formation of the attacking or defending team in ``d``.
+
+        ``d`` maps each column of ``_exprs_variables`` to a numpy array with
+        one entry per object.
+        """
+        xs, ys = d[Column.X], d[Column.Y]
+        team_ids = d[Column.TEAM_ID]
+        ball_owner = d[Column.BALL_OWNING_TEAM_ID]
+        is_gk = d[Column.POSITION_NAME] == "GK"
+        not_gk = d[Column.POSITION_NAME] != "GK"
+
+        if is_attack:
+            in_team = team_ids == ball_owner
+            team_id = ball_owner[0]
+        else:
+            # The ball carries its own team id, so exclude it explicitly.
+            in_team = (team_ids != ball_owner) & (team_ids != Constant.BALL)
+            team_id = team_ids[in_team][0]
+
+        team_idx = np.where(in_team & not_gk)[0]
+        gk_idx = np.where(in_team & is_gk)[0]
+
+        outfield_coordinates = np.stack((xs[team_idx], ys[team_idx]), axis=-1)
+
+        position_labels, _, _, formation, formation_cost = self.__assign_formation(
+            coordinates=outfield_coordinates, direction=direction
+        )
+
+        # Outfield players first, goalkeeper last; labels follow that order.
+        _idxs = np.concatenate((team_idx, gk_idx))
+        labels = np.concatenate((position_labels, ["GK"]))
+        object_ids = d[Column.OBJECT_ID][_idxs]
+
+        if self.__is_update(team_id, formation_cost, object_ids, is_attack):
+            self._forms.set_detected_formation(
+                team_id=team_id,
+                is_attack=is_attack,
+                name=formation,
+                cost=formation_cost,
+                labels=labels,
+                ids=object_ids,
+            )
+
+    def _compute(self, args: List[pl.Series], **kwargs) -> pl.DataFrame:
+        """Per-group callback: detect both teams and return their labels.
+
+        ``args`` holds one series per column of ``_exprs_variables``, in that
+        order.
+        """
+        d = {col: args[i].to_numpy() for i, col in enumerate(self._exprs_variables)}
+
+        d.update(kwargs)
+
+        attacking_team_id = d[Column.BALL_OWNING_TEAM_ID][0]
+        attacking_direction = (
+            AttackingDirection.LTR
+            if self.settings.orientation == Orientation.BALL_OWNING_TEAM
+            else (
+                AttackingDirection.LTR
+                if self.settings.orientation == Orientation.STATIC_HOME_AWAY
+                and attacking_team_id == self.settings.home_team_id
+                else AttackingDirection.RTL
+            )
+        )
+        defending_direction = (
+            AttackingDirection.RTL
+            if attacking_direction == AttackingDirection.LTR
+            else AttackingDirection.LTR
+        )
+
+        self.__detect(
+            is_attack=True,
+            direction=attacking_direction,
+            d=d,
+        )
+        self.__detect(
+            is_attack=False,
+            direction=defending_direction,
+            d=d,
+        )
+
+        return self._forms.get_detected_formations_as_dict(
+            object_ids=d[Column.OBJECT_ID].tolist(), team_ids=d[Column.TEAM_ID].tolist()
+        )
+
+    def fit(
+        self,
+        start_time: pl.duration = None,
+        end_time: pl.duration = None,
+        period_id: int = None,
+        every: Optional[
+            Union[str, Literal["frame"], Literal["period"], Literal["possession"]]
+        ] = "frame",
+        formations: Union[List[str], Literal["shaw-glickman"]] = None,
+        substitutions: Literal["merge", "drop"] = "drop",
+        change_after_possession: bool = True,
+        change_threshold: float = None,
+    ):
+        """Detect formations and positions per frame or per segment.
+
+        Args:
+            start_time, end_time, period_id: optional time window; pass all
+                three or none of them.
+            every: "frame", "possession", "period", or an interval string
+                that is handed to Polars ``dt.truncate`` (e.g. "5m").
+            formations: candidate formations, see ``Formations``.
+            substitutions: what to do when a team has more than 10 distinct
+                outfield players within one segment. Only "drop" is
+                implemented.
+            change_after_possession: re-assign a team's formation when it
+                switches between attacking and defending.
+            change_threshold: minimum relative improvement in assignment cost
+                required to replace the stored formation. ``None`` replaces
+                it every time.
+
+        Returns:
+            ``self``, with ``output`` set, and ``segments`` set to a data
+            frame (``None`` for ``every="frame"``).
+
+        Raises:
+            ValueError: if ``every`` is not a string, if only some of the
+                time window arguments are given, or if ``substitutions`` is
+                not recognised when it is needed.
+            NotImplementedError: for ``substitutions="merge"`` when
+                overlapping substitutions are present.
+        """
+        if not isinstance(every, str):
+            # The annotation allows None, but no code path can handle it.
+            raise ValueError(
+                "'every' should be 'frame', 'possession', 'period' or a time interval string..."
+            )
+
+        self._substitutions = substitutions
+        self._change_threshold = change_threshold
+        self._change_after_possession = change_after_possession
+        self._every = every
+        self._formations = formations
+
+        # Set only when timestamps are shifted onto an arbitrary base date.
+        base_time = None
+
+        self._forms = Formations(
+            pitch_length=self.settings.pitch_dimensions.pitch_length,
+            pitch_width=self.settings.pitch_dimensions.pitch_width,
+            formations=self._formations,
+        )
+
+        if all(x is None for x in [start_time, end_time, period_id]):
+            df = self.dataset
+        elif all(x is not None for x in [start_time, end_time, period_id]):
+            df = self.dataset.filter(
+                (pl.col(Column.TIMESTAMP).is_between(start_time, end_time))
+                & (pl.col(Column.PERIOD_ID) == period_id)
+            )
+        else:
+            raise ValueError(
+                "Please specificy all of start_time, end_time and period_id or none of them..."
+            )
+
+        # Null for the ball, True for the ball-owning team, False otherwise.
+        is_attacking_expr = (
+            pl.when((pl.col(Column.OBJECT_ID) == Constant.BALL))
+            .then(None)
+            .when((pl.col(Column.TEAM_ID) == pl.col(Column.BALL_OWNING_TEAM_ID)))
+            .then(True)
+            .otherwise(False)
+            .alias("is_attacking")
+        )
+
+        if self._every == "frame":
+            group_by_columns = Group.BY_FRAME
+
+            self.output = (
+                (
+                    df.sort([Column.FRAME_ID, Column.OBJECT_ID])
+                    .group_by(group_by_columns, maintain_order=True)
+                    .agg(
+                        pl.map_groups(
+                            exprs=self._exprs_variables,
+                            function=lambda group: self._compute(group),
+                            return_dtype=pl.Struct,
+                        ).alias("result")
+                    )
+                    .unnest("result")
+                )
+                .explode([Column.OBJECT_ID, Column.TEAM_ID, "position", "formation"])
+                .join(
+                    df.select([Column.FRAME_ID, Column.BALL_OWNING_TEAM_ID]).unique(
+                        [Column.FRAME_ID, Column.BALL_OWNING_TEAM_ID]
+                    ),
+                    on=Column.FRAME_ID,
+                    how="left",
+                )
+                .with_columns(is_attacking_expr)
+                .sort([Column.FRAME_ID, "is_attacking", Column.OBJECT_ID])
+            )
+            self.segments = None
+            self._fit = True
+            return self
+
+        is_period = self._every == "period"
+        segment_id = f"{self._every}_id"
+
+        group_by_columns = [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.BALL_OWNING_TEAM_ID,
+            Column.OBJECT_ID,
+        ]
+
+        df = df.with_columns(
+            [
+                (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)).alias(
+                    "is_attacking"
+                )
+            ]
+        )
+        group_by_columns.append("is_attacking")
+
+        if self._every == "possession":
+            # A new possession starts whenever the ball-owning team or the
+            # period differs from the previous row.
+            df1 = df.sort(Column.FRAME_ID).with_columns(
+                [
+                    (
+                        (
+                            pl.col(Column.BALL_OWNING_TEAM_ID)
+                            != pl.col(Column.BALL_OWNING_TEAM_ID).shift(1)
+                        )
+                        | (
+                            pl.col(Column.PERIOD_ID)
+                            != pl.col(Column.PERIOD_ID).shift(1)
+                        )
+                    )
+                    .fill_null(True)
+                    .cast(pl.Int32)
+                    .cum_sum()
+                    .alias(segment_id)
+                ]
+            )
+        elif is_period:
+            # The existing period column doubles as the segment id.
+            df1 = df.sort("frame_id")
+        else:
+            from datetime import datetime
+
+            # Timestamps are durations; shift them onto an arbitrary date so
+            # they can be truncated, and shift back after detection.
+            base_time = datetime(2000, 1, 1)
+
+            df1 = df.sort(Column.FRAME_ID).with_columns(
+                (pl.lit(base_time) + pl.col(Column.TIMESTAMP))
+                .dt.truncate(self._every)
+                .alias(segment_id)
+            )
+
+        # For "period" the segment id is the period column itself; adding it
+        # again would put the same key in the list twice.
+        extra_segment_key = [] if is_period else [segment_id]
+        team_segment_keys = [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.TEAM_ID,
+        ] + extra_segment_key
+        player_segment_keys = [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.TEAM_ID,
+            Column.OBJECT_ID,
+        ] + extra_segment_key
+        possession_segment_keys = [
+            Column.GAME_ID,
+            Column.PERIOD_ID,
+            Column.BALL_OWNING_TEAM_ID,
+        ] + extra_segment_key
+
+        # Any moment we have more than 10 outfield players we have overlapping
+        # substitutions in a segment
+        overlapping_substitutions = (
+            df1.filter(
+                (pl.col(Column.TEAM_ID) != Constant.BALL)
+                & (pl.col(Column.POSITION_NAME) != "GK")
+            )
+            .group_by(team_segment_keys)
+            .agg([pl.col(Column.OBJECT_ID).n_unique().alias("objects")])
+            .sort([segment_id])
+            .filter(pl.col("objects") > 10)
+        )
+
+        if not overlapping_substitutions.is_empty():
+            if self._substitutions == "drop":
+                # Keep the 11 players with the most rows in the segment and
+                # drop everyone ranked below them.
+                player_segments_to_drop = (
+                    df1.join(
+                        overlapping_substitutions,
+                        how="inner",
+                        on=team_segment_keys,
+                    )
+                    .group_by(player_segment_keys)
+                    .agg([pl.len().alias("length")])
+                    .with_columns(
+                        pl.col("length")
+                        .rank(method="ordinal", descending=True)
+                        .over(team_segment_keys)
+                        .alias("rank")
+                    )
+                    .filter(pl.col("rank") > 11)
+                    .drop("rank")
+                    .select(player_segment_keys)
+                )
+                df1 = df1.join(
+                    player_segments_to_drop, on=player_segment_keys, how="anti"
+                )
+            elif self._substitutions == "merge":
+                raise NotImplementedError(
+                    "Merging overlapping substitutions within a window is not implemented yet..."
+                )
+            else:
+                raise ValueError(
+                    "'substitutions' should either be 'merge' or 'drop'..."
+                )
+
+        # Average position per object within each segment.
+        segment_coordinates = (
+            df1.group_by(group_by_columns + extra_segment_key)
+            .agg(
+                [
+                    pl.col(Column.X).mean().alias(Column.X),
+                    pl.col(Column.Y).mean().alias(Column.Y),
+                    pl.col(Column.POSITION_NAME).first().alias(Column.POSITION_NAME),
+                    pl.col(Column.TEAM_ID).first().alias(Column.TEAM_ID),
+                    pl.col(Column.FRAME_ID).unique().len().alias("n_frames"),
+                    pl.col(Column.TIMESTAMP).min().alias("start_timestamp"),
+                    pl.col(Column.TIMESTAMP).max().alias("end_timestamp"),
+                    pl.col(Column.FRAME_ID).min().alias("start_frame_id"),
+                    pl.col(Column.FRAME_ID).max().alias("end_frame_id"),
+                ]
+            )
+            .sort([Column.PERIOD_ID, segment_id, Column.OBJECT_ID])
+        )
+
+        positions = (
+            (
+                segment_coordinates.group_by(
+                    possession_segment_keys,
+                    maintain_order=True,
+                )
+                .agg(
+                    pl.map_groups(
+                        exprs=self._exprs_variables,
+                        function=lambda group: self._compute(group),
+                        return_dtype=pl.Struct,
+                    ).alias("result")
+                )
+                .unnest("result")
+            )
+            .explode([Column.OBJECT_ID, Column.TEAM_ID, "position", "formation"])
+            .with_columns(is_attacking_expr)
+        )
+
+        if base_time is not None:
+            # Undo the arbitrary base date so the segment id is a duration.
+            positions = positions.with_columns(
+                (pl.col(segment_id) - pl.lit(base_time))
+                .cast(pl.Duration)
+                .alias(segment_id)
+            )
+
+        self.output = positions.sort([segment_id, "is_attacking", Column.OBJECT_ID])
+
+        self.segments = (
+            segment_coordinates.select(
+                [
+                    segment_id,
+                    "n_frames",
+                    "start_timestamp",
+                    "end_timestamp",
+                    "start_frame_id",
+                    "end_frame_id",
+                ]
+            )
+            .unique()
+            .sort([segment_id])
+        )
+        self._fit = True
+        return self
diff --git a/unravel/utils/objects/default_settings.py b/unravel/utils/objects/default_settings.py
index 75387cf3..5edad172 100644
--- a/unravel/utils/objects/default_settings.py
+++ b/unravel/utils/objects/default_settings.py
@@ -42,6 +42,7 @@ class DefaultSettings:
max_player_acceleration: float = 6.0
max_ball_acceleration: float = 13.5
ball_carrier_threshold: float = 25.0
+ frame_rate: int = 25
def to_dict(self) -> Dict[str, Any]:
"""Convert the dataclass instance to a dictionary.