From 80ada539e0541187a6428bb549cb9616325ef0a0 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Thu, 22 May 2025 12:27:18 +0200 Subject: [PATCH 1/6] gpu for american football + version --- unravel/__init__.py | 2 +- unravel/american_football/graphs/graph_converter.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/unravel/__init__.py b/unravel/__init__.py index b235f047..2e3377d7 100644 --- a/unravel/__init__.py +++ b/unravel/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.0" +__version__ = "0.4.0" from .soccer import * from .american_football import * diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index 66c2232a..b08b5fdb 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -364,7 +364,9 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]: graph_df = self._convert() self.graph_frames = [ graph - for chunk in graph_df.lazy().collect().iter_slices(self.chunk_size) + for chunk in graph_df.lazy() + .collect(engine="gpu") + .iter_slices(self.chunk_size) for graph in process_chunk(chunk) ] return self.graph_frames From 904033b5cf7946139568556242441e7d1b39cfd3 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 09:21:24 +0200 Subject: [PATCH 2/6] deprecated SoccerGraphConverter, renamed to SoccerGraphConverter --- ...est_bigdb.py => test_american_football.py} | 0 tests/test_kloppy.py | 284 ---- .../{test_kloppy_polars.py => test_soccer.py} | 60 +- tests/test_spektral.py | 101 +- .../graphs/graph_converter.py | 10 - unravel/soccer/graphs/__init__.py | 5 +- unravel/soccer/graphs/graph_converter.py | 1248 ++++++++++++++--- unravel/soccer/graphs/graph_converter_pl.py | 1162 --------------- unravel/soccer/graphs/graph_frame.py | 104 -- unravel/soccer/graphs/graph_settings.py | 36 +- unravel/soccer/graphs/graph_settings_pl.py | 57 - .../utils/objects/default_graph_converter.py 
| 11 +- 12 files changed, 1186 insertions(+), 1892 deletions(-) rename tests/{test_bigdb.py => test_american_football.py} (100%) delete mode 100644 tests/test_kloppy.py rename tests/{test_kloppy_polars.py => test_soccer.py} (96%) delete mode 100644 unravel/soccer/graphs/graph_converter_pl.py delete mode 100644 unravel/soccer/graphs/graph_frame.py delete mode 100644 unravel/soccer/graphs/graph_settings_pl.py diff --git a/tests/test_bigdb.py b/tests/test_american_football.py similarity index 100% rename from tests/test_bigdb.py rename to tests/test_american_football.py diff --git a/tests/test_kloppy.py b/tests/test_kloppy.py deleted file mode 100644 index 4fb2273a..00000000 --- a/tests/test_kloppy.py +++ /dev/null @@ -1,284 +0,0 @@ -from pathlib import Path -from unravel.soccer import SoccerGraphConverter, GraphFrame, SoccerGraphSettings -from unravel.utils import ( - DefaultTrackingModel, - dummy_labels, - dummy_graph_ids, - CustomSpektralDataset, -) - -from kloppy import skillcorner -from kloppy.domain import Ground, TrackingDataset, Orientation -from typing import List, Dict - -from spektral.data import Graph - -import pytest - -import numpy as np - - -class TestKloppyData: - - @pytest.fixture - def match_data(self, base_dir: Path) -> str: - return base_dir / "files" / "skillcorner_match_data.json" - - @pytest.fixture - def structured_data(self, base_dir: Path) -> str: - return base_dir / "files" / "skillcorner_structured_data.json.gz" - - @pytest.fixture() - def dataset(self, match_data: str, structured_data: str) -> TrackingDataset: - return skillcorner.load( - raw_data=structured_data, - meta_data=match_data, - coordinates="tracab", - include_empty_frames=False, - limit=500, - ) - - @pytest.fixture() - def gnnc(self, dataset: TrackingDataset) -> SoccerGraphConverter: - return SoccerGraphConverter( - dataset=dataset, - labels=dummy_labels(dataset), - graph_ids=dummy_graph_ids(dataset), - ball_carrier_treshold=25.0, - max_player_speed=12.0, - max_ball_speed=28.0, 
- boundary_correction=None, - self_loop_ball=True, - adjacency_matrix_connect_type="ball", - adjacency_matrix_type="split_by_team", - label_type="binary", - defending_team_node_value=0.0, - non_potential_receiver_node_value=0.1, - infer_ball_ownership=True, - infer_goalkeepers=True, - random_seed=False, - pad=False, - verbose=False, - ) - - @pytest.fixture() - def gnnc_padding(self, dataset: TrackingDataset) -> SoccerGraphConverter: - return SoccerGraphConverter( - dataset=dataset, - labels=dummy_labels(dataset), - graph_id=1234, - ball_carrier_treshold=25.0, - max_player_speed=12.0, - max_ball_speed=28.0, - boundary_correction=None, - self_loop_ball=False, - adjacency_matrix_connect_type="ball", - adjacency_matrix_type="split_by_team", - label_type="binary", - defending_team_node_value=0.0, - non_potential_receiver_node_value=0.1, - infer_ball_ownership=True, - infer_goalkeepers=True, - random_seed=False, - pad=True, - verbose=False, - ) - - @pytest.fixture() - def gnnc_padding_random(self, dataset: TrackingDataset) -> SoccerGraphConverter: - return SoccerGraphConverter( - dataset=dataset, - labels=dummy_labels(dataset), - # settings - ball_carrier_treshold=25.0, - max_player_speed=12.0, - max_ball_speed=28.0, - boundary_correction=None, - self_loop_ball=False, - adjacency_matrix_connect_type="ball", - adjacency_matrix_type="split_by_team", - label_type="binary", - defending_team_node_value=0.0, - non_potential_receiver_node_value=0.1, - infer_ball_ownership=True, - infer_goalkeepers=True, - random_seed=42, - pad=True, - verbose=False, - ) - - def test_conversion(self, gnnc: SoccerGraphConverter): - data, label, frame_id, _ = gnnc._convert(gnnc.dataset[2]) - - assert isinstance(data, DefaultTrackingModel) - assert frame_id == 1525 - - assert data.attacking_team == Ground.HOME - assert data.orientation == Orientation.STATIC_HOME_AWAY - assert data.attacking_players == data.home_players - assert data.defending_players == data.away_players - - hp = 
data.home_players[3] - assert -19.582426479899993 == pytest.approx(hp.x1, abs=1e-5) - assert 24.3039460863 == pytest.approx(hp.y1, abs=1e-5) - assert -19.6022318885 == pytest.approx(hp.x2, abs=1e-5) - assert 24.1632567814 == pytest.approx(hp.y2, abs=1e-5) - assert hp.position.shape == (2,) - np.testing.assert_allclose( - hp.position, np.asarray([hp.x1, hp.y1]), rtol=1e-4, atol=1e-4 - ) - assert hp.is_gk == False - assert hp.next_position[0] - hp.position[0] - - assert data.ball_carrier_idx == 1 - assert len(data.home_players) == 6 - assert len(data.away_players) == 4 - - defending_team_value_node_idx = 10 - non_potential_receiver_node_idx = 11 - - gnn_frame = GraphFrame( - frame_id=frame_id, - data=data, - label=label, - graph_id="abcdefg", - settings=gnnc.settings, - ) - x = gnn_frame.graph_data.get("x") - gid = gnn_frame.graph_data.get("id") - - assert x[9, defending_team_value_node_idx] == 0.0 - assert x[1, non_potential_receiver_node_idx] == 0.1 - - assert gid == "abcdefg" - - def test_conversion_padding(self, gnnc_padding: SoccerGraphConverter): - data, _, frame_id, graph_id = gnnc_padding._convert(gnnc_padding.dataset[2]) - - assert isinstance(data, DefaultTrackingModel) - assert frame_id == 1525 - assert graph_id == 1234 - - assert data.attacking_team == Ground.HOME - assert data.attacking_players == data.home_players - assert data.defending_players == data.away_players - assert data.ball_carrier_idx == 1 - assert len(data.home_players) == 11 - assert len(data.away_players) == 11 - - def test_to_spektral_graph(self, gnnc: SoccerGraphConverter): - """ - Test navigating (next/prev) through events - """ - spektral_graphs = gnnc.to_spektral_graphs() - - assert 1 == 1 - - data = spektral_graphs - assert len(data) == 387 - assert isinstance(data[0], Graph) - # note: these shape tests fail if we add more features (ie. 
acceleration) - - x = data[0].x - assert x.shape == (10, 12) - assert -0.42531483968190475 == pytest.approx(x[0, 0], abs=1e-5) - assert 0.188 == pytest.approx(x[0, 4], abs=1e-5) - assert 0.5614587302341536 == pytest.approx(x[8, 2], abs=1e-5) - - e = data[0].e - assert e.shape == (60, 7) - assert 0.0 == pytest.approx(e[0, 0], abs=1e-5) - assert 0.5 == pytest.approx(e[0, 4], abs=1e-5) - assert 0.31674592566440973 == pytest.approx(e[8, 2], abs=1e-5) - - a = data[0].a - assert a.shape == (10, 10) - assert 1.0 == pytest.approx(a[0, 0], abs=1e-5) - assert 1.0 == pytest.approx(a[0, 4], abs=1e-5) - assert 0.0 == pytest.approx(a[8, 2], abs=1e-5) - - dataset = CustomSpektralDataset(graphs=spektral_graphs) - N, F, S, n_out, n = dataset.dimensions() - assert N == 21 - assert F == 12 - assert S == 7 - assert n_out == 1 - assert n == 387 - - train, test, val = dataset.split_test_train_validation( - split_train=4, - split_test=1, - split_validation=1, - by_graph_id=True, - random_seed=42, - ) - assert train.n_graphs == 233 - assert test.n_graphs == 77 - assert val.n_graphs == 77 - - train, test, val = dataset.split_test_train_validation( - split_train=4, - split_test=1, - split_validation=1, - by_graph_id=False, - random_seed=42, - ) - assert train.n_graphs == 258 - assert test.n_graphs == 64 - assert val.n_graphs == 65 - - train, test = dataset.split_test_train( - split_train=4, split_test=1, by_graph_id=False, random_seed=42 - ) - assert train.n_graphs == 309 - assert test.n_graphs == 78 - - train, test = dataset.split_test_train( - split_train=4, split_test=5, by_graph_id=False, random_seed=42 - ) - assert train.n_graphs == 172 - assert test.n_graphs == 215 - - with pytest.raises( - NotImplementedError, - match="Make sure split_train > split_test >= split_validation, other behaviour is not supported when by_graph_id is True...", - ): - dataset.split_test_train( - split_train=4, split_test=5, by_graph_id=True, random_seed=42 - ) - - def test_to_spektral_graph_padding_random( - 
self, gnnc_padding_random: SoccerGraphConverter - ): - """ - Test navigating (next/prev) through events - """ - gnnc_padding_random.to_graph_frames() - - spektral_graphs = [ - g.to_spektral_graph() for g in gnnc_padding_random.graph_frames - ] - - # with random seed = 42 the permuntation is [15 9 0 8 17 12 1 13 5 2 11 20 3 4 18 16 21 22 7 10 14 19 6] - assert 1 == 1 - - data = spektral_graphs - assert len(data) == 387 - assert isinstance(data[0], Graph) - # note: these shape tests fail if we add more features (ie. acceleration) - - x = data[0].x - assert x.shape == (23, 12) - assert -0.42531483968190475 == pytest.approx(x[2, 0], abs=1e-5) - assert 0.188 == pytest.approx(x[2, 4], abs=1e-5) - assert 0.5614587302341536 == pytest.approx(x[20, 2], abs=1e-5) - - e = data[0].e - assert e.shape == (287, 7) - assert 0.0 == pytest.approx(e[0, 0], abs=1e-5) - assert 0.4261188174 == pytest.approx(e[75, 4], abs=1e-5) - assert 0.31674592566440973 == pytest.approx(e[119, 2], abs=1e-5) - - a = data[0].a - assert a.shape == (23, 23) - assert 1.0 == pytest.approx(a[2, 2], abs=1e-5) diff --git a/tests/test_kloppy_polars.py b/tests/test_soccer.py similarity index 96% rename from tests/test_kloppy_polars.py rename to tests/test_soccer.py index e15357fb..e55948ca 100644 --- a/tests/test_kloppy_polars.py +++ b/tests/test_soccer.py @@ -1,6 +1,6 @@ from pathlib import Path from unravel.soccer import ( - SoccerGraphConverterPolars, + SoccerGraphConverter, KloppyPolarsDataset, PressingIntensity, Constant, @@ -165,7 +165,7 @@ def kloppy_polars_dataset( @pytest.fixture() def spc_padding( self, kloppy_polars_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: ds = kloppy_polars_dataset ds.data = ds.data.with_columns( [pl.lit(1.0).alias("fake_global_feature_column")] @@ -198,7 +198,7 @@ def spc_padding( .drop(["ball_x", "ball_y", "ball_z"]) ) - return SoccerGraphConverterPolars( + return SoccerGraphConverter( dataset=kloppy_polars_dataset, 
chunk_size=2_0000, non_potential_receiver_node_value=0.1, @@ -217,9 +217,9 @@ def spc_padding( @pytest.fixture() def soccer_polars_converter( self, kloppy_polars_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: - return SoccerGraphConverterPolars( + return SoccerGraphConverter( dataset=kloppy_polars_dataset, chunk_size=2_0000, non_potential_receiver_node_value=0.1, @@ -236,15 +236,15 @@ def soccer_polars_converter( @pytest.fixture() def soccer_polars_converter_sportec( self, kloppy_polars_sportec_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: kloppy_polars_sportec_dataset.add_dummy_labels() kloppy_polars_sportec_dataset.add_graph_ids() - return SoccerGraphConverterPolars(dataset=kloppy_polars_sportec_dataset) + return SoccerGraphConverter(dataset=kloppy_polars_sportec_dataset) @pytest.fixture() def soccer_polars_converter_graph_and_additional_features( self, kloppy_polars_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: kloppy_polars_dataset.data = ( kloppy_polars_dataset.data @@ -269,7 +269,7 @@ def custom_edge_feature(**kwargs): def custom_node_feature(**kwargs): return kwargs["fake_additional_feature_a"] - return SoccerGraphConverterPolars( + return SoccerGraphConverter( dataset=kloppy_polars_dataset, global_feature_cols=["fake_graph_feature_a", "fake_graph_feature_b"], additional_feature_cols=["fake_additional_feature_a"], @@ -309,7 +309,7 @@ def custom_node_feature(**kwargs): def test_incorrect_custom_features( self, kloppy_polars_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: kloppy_polars_dataset.data = ( kloppy_polars_dataset.data @@ -331,7 +331,7 @@ def custom_edge_feature(**kwargs): ) with pytest.raises(Exception): - SoccerGraphConverterPolars( + SoccerGraphConverter( dataset=kloppy_polars_dataset, global_feature_cols=["fake_graph_feature_a", "fake_graph_feature_b"], 
additional_feature_cols=["fake_additional_feature_a"], @@ -356,7 +356,7 @@ def custom_edge_feature(**kwargs): def test_incorrect_custom_features_no_decorator( self, kloppy_polars_dataset: KloppyPolarsDataset - ) -> SoccerGraphConverterPolars: + ) -> SoccerGraphConverter: kloppy_polars_dataset.data = ( kloppy_polars_dataset.data @@ -377,7 +377,7 @@ def custom_edge_feature(**kwargs): ) with pytest.raises(Exception): - SoccerGraphConverterPolars( + SoccerGraphConverter( dataset=kloppy_polars_dataset, global_feature_cols=["fake_graph_feature_a", "fake_graph_feature_b"], additional_feature_cols=["fake_additional_feature_a"], @@ -402,7 +402,7 @@ def custom_edge_feature(**kwargs): def test_node_feature_computation( self, - soccer_polars_converter_sportec: SoccerGraphConverterPolars, + soccer_polars_converter_sportec: SoccerGraphConverter, single_frame: dict, single_frame_node_feature_result: np.ndarray, ): @@ -448,7 +448,7 @@ def test_node_feature_computation( def test_edge_feature_computation( self, - soccer_polars_converter_sportec: SoccerGraphConverterPolars, + soccer_polars_converter_sportec: SoccerGraphConverter, single_frame: dict, single_frame_edge_feature_result: np.ndarray, single_frame_adj_matrix_result: np.ndarray, @@ -788,7 +788,7 @@ def test_pi_full_include_ball_owning_speed_0( count = np.count_nonzero(np.isclose(arr, 0.0, atol=1e-5)) assert count == 117 - def test_padding(self, spc_padding: SoccerGraphConverterPolars): + def test_padding(self, spc_padding: SoccerGraphConverter): spektral_graphs = spc_padding.to_spektral_graphs() assert 1 == 1 @@ -800,7 +800,7 @@ def test_padding(self, spc_padding: SoccerGraphConverterPolars): assert len(data) == 245 assert isinstance(data[0], Graph) - def spektral_graph(self, soccer_polars_converter: SoccerGraphConverterPolars): + def spektral_graph(self, soccer_polars_converter: SoccerGraphConverter): """ Test navigating (next/prev) through events """ @@ -899,7 +899,7 @@ def spektral_graph(self, soccer_polars_converter: 
SoccerGraphConverterPolars): def test_to_spektral_graph_level_features( self, - soccer_polars_converter_graph_and_additional_features: SoccerGraphConverterPolars, + soccer_polars_converter_graph_and_additional_features: SoccerGraphConverter, single_frame_node_feature_global_result_file: str, ): """ @@ -974,7 +974,7 @@ def test_line_method(self): assert np.array_equal(valid_mask, np.array([True, True, False, False])) - def test_plot_graph(self, soccer_polars_converter: SoccerGraphConverterPolars): + def test_plot_graph(self, soccer_polars_converter: SoccerGraphConverter): plot_path = join("tests", "files", "plot", "test-1.mp4") soccer_polars_converter.plot( @@ -989,9 +989,7 @@ def test_plot_graph(self, soccer_polars_converter: SoccerGraphConverterPolars): color_by="ball_owning", ) - def test_plot_png_success( - self, soccer_polars_converter: SoccerGraphConverterPolars - ): + def test_plot_png_success(self, soccer_polars_converter: SoccerGraphConverter): """Test successful PNG generation with correct parameters.""" # Setup test file path plot_path = os.path.join("tests", "files", "plot", "test-png.png") @@ -1015,9 +1013,7 @@ def test_plot_png_success( assert os.path.exists(plot_path) assert plot_path.endswith(".png") - def test_plot_png_no_extension( - self, soccer_polars_converter: SoccerGraphConverterPolars - ): + def test_plot_png_no_extension(self, soccer_polars_converter: SoccerGraphConverter): """Test PNG generation when no file extension is provided.""" # Setup test file path without extension plot_path = os.path.join("tests", "files", "plot", "test-no-extension") @@ -1042,9 +1038,7 @@ def test_plot_png_no_extension( # Check that the file was created with .png extension assert os.path.exists(expected_path) - def test_plot_error_only_fps( - self, soccer_polars_converter: SoccerGraphConverterPolars - ): + def test_plot_error_only_fps(self, soccer_polars_converter: SoccerGraphConverter): """Test error is raised when only fps is provided without end_timestamp.""" 
with pytest.raises(ValueError): soccer_polars_converter.plot( @@ -1055,7 +1049,7 @@ def test_plot_error_only_fps( ) def test_plot_error_empty_selection( - self, soccer_polars_converter: SoccerGraphConverterPolars + self, soccer_polars_converter: SoccerGraphConverter ): with pytest.raises(ValueError): soccer_polars_converter.plot( @@ -1065,7 +1059,7 @@ def test_plot_error_empty_selection( ) def test_plot_error_only_end_timestamp( - self, soccer_polars_converter: SoccerGraphConverterPolars + self, soccer_polars_converter: SoccerGraphConverter ): """Test error is raised when only end_timestamp is provided without fps.""" with pytest.raises(ValueError): @@ -1077,7 +1071,7 @@ def test_plot_error_only_end_timestamp( ) def test_plot_error_mp4_extension_without_video_params( - self, soccer_polars_converter: SoccerGraphConverterPolars + self, soccer_polars_converter: SoccerGraphConverter ): """Test error when .mp4 extension is used but video parameters are not provided.""" with pytest.raises(ValueError): @@ -1089,7 +1083,7 @@ def test_plot_error_mp4_extension_without_video_params( ) def test_plot_error_wrong_extension_for_png( - self, soccer_polars_converter: SoccerGraphConverterPolars + self, soccer_polars_converter: SoccerGraphConverter ): """Test error when non-png/mp4 extension is used for image output.""" with pytest.raises(ValueError): @@ -1100,7 +1094,7 @@ def test_plot_error_wrong_extension_for_png( ) def test_plot_error_wrong_extension_for_mp4( - self, soccer_polars_converter: SoccerGraphConverterPolars + self, soccer_polars_converter: SoccerGraphConverter ): """Test error when non-mp4 extension is used for video output.""" with pytest.raises(ValueError): diff --git a/tests/test_spektral.py b/tests/test_spektral.py index db273f90..a4400b36 100644 --- a/tests/test_spektral.py +++ b/tests/test_spektral.py @@ -1,5 +1,5 @@ from pathlib import Path -from unravel.soccer import SoccerGraphConverter +from unravel.soccer import KloppyPolarsDataset, SoccerGraphConverter from 
unravel.american_football import BigDataBowlDataset, AmericanFootballGraphConverter from unravel.utils import dummy_labels, dummy_graph_ids, CustomSpektralDataset from unravel.classifiers import CrystalGraphClassifier @@ -71,49 +71,98 @@ def kloppy_dataset(self, match_data: str, structured_data: str) -> TrackingDatas ) @pytest.fixture() - def soccer_converter(self, kloppy_dataset: TrackingDataset) -> SoccerGraphConverter: - return SoccerGraphConverter( - dataset=kloppy_dataset, - labels=dummy_labels(kloppy_dataset), - graph_ids=dummy_graph_ids(kloppy_dataset), - ball_carrier_treshold=25.0, + def kloppy_polars_dataset( + self, kloppy_dataset: TrackingDataset + ) -> KloppyPolarsDataset: + dataset = KloppyPolarsDataset( + kloppy_dataset=kloppy_dataset, + ball_carrier_threshold=25.0, max_player_speed=12.0, - max_ball_speed=28.0, - boundary_correction=None, + max_player_acceleration=12.0, + max_ball_speed=13.5, + max_ball_acceleration=100, + ) + dataset.add_dummy_labels(by=["game_id", "frame_id"], random_seed=42) + dataset.add_graph_ids(by=["game_id", "frame_id"]) + return dataset + + @pytest.fixture() + def soccer_converter( + self, kloppy_polars_dataset: KloppyPolarsDataset + ) -> SoccerGraphConverter: + # return SoccerGraphConverterDeprecated( + # dataset=kloppy_dataset, + # labels=dummy_labels(kloppy_dataset), + # graph_ids=dummy_graph_ids(kloppy_dataset), + # ball_carrier_treshold=25.0, + # max_player_speed=12.0, + # max_ball_speed=28.0, + # boundary_correction=None, + # self_loop_ball=True, + # adjacency_matrix_connect_type="ball", + # adjacency_matrix_type="split_by_team", + # label_type="binary", + # defending_team_node_value=0.0, + # non_potential_receiver_node_value=0.1, + # infer_ball_ownership=True, + # infer_goalkeepers=True, + # random_seed=42, + # pad=False, + # verbose=False, + # ) + return SoccerGraphConverter( + dataset=kloppy_polars_dataset, + chunk_size=2_0000, + non_potential_receiver_node_value=0.1, self_loop_ball=True, 
adjacency_matrix_connect_type="ball", adjacency_matrix_type="split_by_team", label_type="binary", defending_team_node_value=0.0, - non_potential_receiver_node_value=0.1, - infer_ball_ownership=True, - infer_goalkeepers=True, random_seed=42, - pad=False, + pad=True, verbose=False, ) @pytest.fixture() def soccer_converter_preds( - self, kloppy_dataset: TrackingDataset + self, kloppy_polars_dataset: KloppyPolarsDataset ) -> SoccerGraphConverter: + # @pytest.fixture() + # def soccer_converter_preds( + # self, kloppy_dataset: TrackingDataset + # ) -> SoccerGraphConverterDeprecated: + # return SoccerGraphConverterDeprecated( + # dataset=kloppy_dataset, + # prediction=True, + # ball_carrier_treshold=25.0, + # max_player_speed=12.0, + # max_ball_speed=28.0, + # boundary_correction=None, + # self_loop_ball=True, + # adjacency_matrix_connect_type="ball", + # adjacency_matrix_type="split_by_team", + # label_type="binary", + # defending_team_node_value=0.0, + # non_potential_receiver_node_value=0.1, + # infer_ball_ownership=True, + # infer_goalkeepers=True, + # random_seed=42, + # pad=False, + # verbose=False, + # ) return SoccerGraphConverter( - dataset=kloppy_dataset, + dataset=kloppy_polars_dataset, prediction=True, - ball_carrier_treshold=25.0, - max_player_speed=12.0, - max_ball_speed=28.0, - boundary_correction=None, + chunk_size=2_0000, + non_potential_receiver_node_value=0.1, self_loop_ball=True, adjacency_matrix_connect_type="ball", adjacency_matrix_type="split_by_team", label_type="binary", defending_team_node_value=0.0, - non_potential_receiver_node_value=0.1, - infer_ball_ownership=True, - infer_goalkeepers=True, random_seed=42, - pad=False, + pad=True, verbose=False, ) @@ -214,10 +263,10 @@ def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter): df = pd.DataFrame( {"frame_id": [x.id for x in pred_dataset], "y": preds.flatten()} - ) + ).sort_values(by=["frame_id"]) - assert df["frame_id"].iloc[0] == 1524 - assert df["frame_id"].iloc[-1] == 
1621 + assert df["frame_id"].iloc[0] == "2417-1524" + assert df["frame_id"].iloc[-1] == "2417-1622" def test_bdb_training(self, bdb_converter: AmericanFootballGraphConverter): train = CustomSpektralDataset(graphs=bdb_converter.to_spektral_graphs()) diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index b08b5fdb..ae8310a6 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -84,16 +84,6 @@ def _sample(self): pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0 ) - def _shuffle(self): - if isinstance(self.settings.random_seed, int): - self.dataset = self.dataset.sample( - fraction=1.0, seed=self.settings.random_seed - ) - elif self.settings.random_seed == True: - self.dataset = self.dataset.sample(fraction=1.0) - else: - pass - def _sport_specific_checks(self): def __remove_with_missing_values(min_object_count: int = 10): cs = ( diff --git a/unravel/soccer/graphs/__init__.py b/unravel/soccer/graphs/__init__.py index 843e4057..61651d94 100644 --- a/unravel/soccer/graphs/__init__.py +++ b/unravel/soccer/graphs/__init__.py @@ -1,7 +1,4 @@ from .graph_converter import SoccerGraphConverter -from .graph_converter_pl import SoccerGraphConverterPolars -from .graph_settings import SoccerGraphSettings -from .graph_settings_pl import GraphSettingsPolars -from .graph_frame import GraphFrame +from .graph_settings import GraphSettingsPolars from .exceptions import * from .features import * diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index a958fffc..b95c5711 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -1,41 +1,29 @@ import logging import sys -from copy import deepcopy -from scipy.spatial.qhull import QhullError +from dataclasses import dataclass -from warnings import warn, simplefilter +from typing import List, Union, Dict, 
Literal, Any, Optional, Callable -from dataclasses import dataclass, field, asdict +import inspect -from typing import List, Union, Dict, Literal +import pathlib -from kloppy.domain import ( - TrackingDataset, - Frame, - Orientation, - DatasetTransformer, - DatasetFlag, - SecondSpectrumCoordinateSystem, -) +from kloppy.domain import MetricPitchDimensions, Orientation from spektral.data import Graph -from .exceptions import ( - MissingLabelsError, - MissingDatasetError, - IncorrectDatasetTypeError, - KeyMismatchError, +from .graph_settings import GraphSettingsPolars +from ..dataset.kloppy_polars import KloppyPolarsDataset, Column, Group, Constant +from .features import ( + compute_node_features, + add_global_features, + compute_adjacency_matrix, + compute_edge_features, ) -from .graph_settings import SoccerGraphSettings -from .graph_frame import GraphFrame - from ...utils import * -simplefilter("always", DeprecationWarning) - - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) stdout_handler = logging.StreamHandler(sys.stdout) @@ -48,71 +36,375 @@ class SoccerGraphConverter(DefaultGraphConverter): Converts our dataset TrackingDataset into an internal structure Attributes: - dataset (TrackingDataset): Kloppy TrackingDataset. - labels (dict): Dict with a key per frame_id, like so {frame_id: True/False/1/0} - graph_id (str, int): Set a single id for the whole Kloppy dataset. - graph_ids (dict): Frame level control over graph ids. - - The graph_ids will be used to assign each graph an identifier. This identifier allows us to split the CustomSpektralDataset such that - all graphs with the same id are either all in the test, train or validation set to avoid leakage. It is recommended to either set graph_id (int, str) as - a match_id, or pass a dictionary into 'graph_ids' with exactly the same keys as 'labels' for more granualar control over the graph ids. - The latter can be useful when splitting graphs by possession or sequence id. 
In this case the dict would be {frame_id: sequence_id/possession_id}. - Note that sequence_id/possession_id should probably be unique for the whole dataset. Perhaps like so {frame_id: 'match_id-sequence_id'}. Defaults to None. - - infer_ball_ownership (bool): - Infers 'attacking_team' if no 'ball_owning_team' (Kloppy) or 'attacking_team' (List[Dict]) is provided, by finding player closest to ball using ball xyz. - Also infers ball_carrier within ball_carrier_threshold - infer_goalkeepers (bool): set True if no GK label is provider, set False for incomplete (broadcast tracking) data that might not have a GK in every frame - max_ball_speed (float): The maximum speed of the ball in meters per second. Defaults to 28.0. - max_player_speed (float): The maximum speed of a player in meters per second. Defaults to 12.0. - max_ball_speed (float): The maximum speed of the ball in meters per second. Defaults to 28.0. - ball_carrier_threshold (float): The distance threshold to determine the ball carrier. Defaults to 25.0. - boundary_correction (float): A correction factor for boundary calculations, used to correct out of bounds as a percentages (Used as 1+boundary_correction, ie 0.05). Defaults to None. + dataset (KloppyPolarsDataset): KloppyPolarsDataset created from a Kloppy dataset. + chunk_size (int): Determines how many Graphs get processed simultanously. non_potential_receiver_node_value (float): Value between 0 and 1 to assign to the defing team players + global_feature_cols (list[str]): List of columns in the dataset that are Graph level features (e.g. team strength rating, win probabilities etc) + we want to add to our model. A list of column names corresponding to the Polars dataframe within KloppyPolarsDataset.data + that are graph level features. They should be joined to the KloppyPolarsDataset.data dataframe such that + each Group in the group_by has the same value per column. 
We take the first value of the group, and assign this as a + "graph level feature" to the ball node. + global_feature_type: A literal of type "ball" or "all". When set to "ball" the global features will be assigned to only the ball node, if set to "all" + they will be assigned to every player and ball in the node features. + edge_feature_funcs: A list of functions (decorated with @graph_feature(is_custom, feature_type="edge")) + that take **kwargs as input and return a numpy array (dimensions should match expected (N,N) shape or tuple with multiple (N, N) numpy arrays). + node_feature_funcs: A list of functions (decorated with @graph_feature(is_custom, feature_type="node")) + that take **kwargs as input and return a numpy array (dimensions should match expected (N,) shape or (N, k) ) + additional_feature_cols: Columns the user has added to the 'KloppyPolarsDataset.data' that are not to be added as global features, + but can now be accessed by edge_feature_funcs and node_feature_funcs through kwargs. + (e.g. 
if the user adds "height" for each player, as a column to the 'KloppyPolarsDataset.data' and + they want to use it to compute the height difference between all players as an edge feature they would + pass additional_feature_cols=["height"] and their custom edge feature function can now access kwargs['height']) """ - dataset: TrackingDataset = None - labels: dict = None + dataset: KloppyPolarsDataset = None - graph_id: Union[str, int, dict] = None - graph_ids: dict = None + chunk_size: int = 2_0000 + non_potential_receiver_node_value: float = 0.1 - infer_goalkeepers: bool = True - infer_ball_ownership: bool = True - boundary_correction: float = None + edge_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field( + repr=False, default_factory=list + ) + node_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field( + repr=False, default_factory=list + ) - max_player_speed: float = 12.0 - max_ball_speed: float = 28.0 - # max_player_acceleration: float = 6.0 - # max_ball_acceleration: float = 13.5 - ball_carrier_treshold: float = 25.0 + global_feature_cols: Optional[List[str]] = field(repr=False, default_factory=list) + global_feature_type: Literal["ball", "all"] = "ball" - non_potential_receiver_node_value: float = 0.1 + additional_feature_cols: Optional[List[str]] = field( + repr=False, default_factory=list + ) + + _edge_feature_dims: Dict[str, int] = field( + repr=False, default_factory=dict, init=False + ) + _node_feature_dims: Dict[str, int] = field( + repr=False, default_factory=dict, init=False + ) def __post_init__(self): - warn( - """ - This class is deprecated and will be removed in a future release. Please use SoccerGraphConverterPolars instead. - Note: SoccerGraphConverterPolars is not one-to-one compatible with models and dataset created from SoccerGraphConverter due to breaking changes. 
- """, - category=DeprecationWarning, - stacklevel=2, + if not isinstance(self.dataset, KloppyPolarsDataset): + raise ValueError("dataset should be of type KloppyPolarsDataset...") + + self.pitch_dimensions: MetricPitchDimensions = ( + self.dataset.settings.pitch_dimensions + ) + self._kloppy_settings = self.dataset.settings + + self.label_column: str = ( + self.label_col if self.label_col is not None else self.dataset._label_column ) - if not self.dataset: - raise Exception("Please provide a 'kloppy' dataset.") + self.graph_id_column: str = ( + self.graph_id_col + if self.graph_id_col is not None + else self.dataset._graph_id_column + ) + + self.dataset = self.dataset.data + + if not self.edge_feature_funcs: + self.edge_feature_funcs = self.default_edge_feature_funcs + + self._verify_feature_funcs(self.edge_feature_funcs, feature_type="edge") + + if not self.node_feature_funcs: + self.node_feature_funcs = self.default_node_feature_funcs + + self._verify_feature_funcs(self.node_feature_funcs, feature_type="node") self._sport_specific_checks() - self.settings = SoccerGraphSettings( - ball_carrier_treshold=self.ball_carrier_treshold, - max_player_speed=self.max_player_speed, - max_ball_speed=self.max_ball_speed, - boundary_correction=self.boundary_correction, + self.settings = self._apply_graph_settings() + + if self.pad: + self.dataset = self._apply_padding() + else: + self.dataset = self._remove_incomplete_frames() + + self._sample() + self._shuffle() + + def _sample(self): + if self.sample_rate is None: + return + else: + self.dataset = self.dataset.filter( + pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0 + ) + + def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): + for i, func in enumerate(funcs): + # Check if it has the attributes added by the decorator + if not hasattr(func, "feature_type"): + func_str = inspect.getsource(func).strip() + raise Exception( + f"Error processing feature function:\n" + f"{func.__name__} defined 
as:\n" + f"{func_str}\n\n" + "Function is missing the @graph_feature decorator. " + ) + + if func.feature_type != feature_type: + func_str = inspect.getsource(func).strip() + raise Exception( + f"Error processing feature function:\n" + f"{func.__name__} defined as:\n" + f"{func_str}\n\n" + "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. " + ) + + @staticmethod + def _sort(df): + sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( + (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) + & (pl.col(Column.TEAM_ID) != Constant.BALL) + ).cast(int) + + df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]) + return df + + def _remove_incomplete_frames(self) -> pl.DataFrame: + df = self.dataset + total_frames = len(df.unique(Group.BY_FRAME)) + + valid_frames = ( + df.group_by(Group.BY_FRAME) + .agg(pl.col(Column.TEAM_ID).n_unique().alias("unique_teams")) + .filter(pl.col("unique_teams") == 3) + .select(Group.BY_FRAME) + ) + dropped_frames = total_frames - len(valid_frames.unique(Group.BY_FRAME)) + if dropped_frames > 0 and self.verbose: + self.__warn_dropped_frames(dropped_frames, total_frames) + + return df.join(valid_frames, on=Group.BY_FRAME) + + def _apply_padding(self) -> pl.DataFrame: + df = self.dataset + + keep_columns = [ + Column.TIMESTAMP, + Column.BALL_STATE, + self.label_column, + self.graph_id_column, + ] + empty_columns = [ + Column.POSITION_NAME, + Column.OBJECT_ID, + Column.IS_BALL_CARRIER, + Column.X, + Column.Y, + Column.Z, + Column.VX, + Column.VY, + Column.VZ, + Column.SPEED, + Column.AX, + Column.AY, + Column.AZ, + Column.ACCELERATION, + ] + group_by_columns = [ + Column.GAME_ID, + Column.PERIOD_ID, + Column.FRAME_ID, + Column.TEAM_ID, + Column.BALL_OWNING_TEAM_ID, + ] + + user_defined_columns = [ + x + for x in df.columns + if x + not in keep_columns + + group_by_columns + + empty_columns + + self.global_feature_cols + ] + + counts = 
df.group_by(group_by_columns).agg( + pl.len().alias("count"), + *[ + pl.first(col).alias(col) + for col in keep_columns + self.global_feature_cols + ], + ) + + counts = counts.with_columns( + [ + pl.when(pl.col(Column.TEAM_ID) == Constant.BALL) + .then(1) + .when(pl.col(Column.TEAM_ID) == pl.col(Column.BALL_OWNING_TEAM_ID)) + .then(11) + .otherwise(11) + .alias("target_length") + ] + ) + + groups_to_pad = counts.filter( + pl.col("count") < pl.col("target_length") + ).with_columns((pl.col("target_length") - pl.col("count")).alias("repeats")) + + padding_rows = [] + # This is where we pad players (missing balls get skipped because of 'target_length') + for row in groups_to_pad.iter_rows(named=True): + base_row = { + col: row[col] + for col in keep_columns + group_by_columns + self.global_feature_cols + } + padding_rows.extend([base_row] * row["repeats"]) + + # Now check if there are frames without ball rows + # Get all unique frames + all_frames = df.select( + [ + Column.GAME_ID, + Column.PERIOD_ID, + Column.FRAME_ID, + Column.BALL_OWNING_TEAM_ID, + ] + + keep_columns + + self.global_feature_cols + ).unique() + + # Get frames that have ball rows + frames_with_ball = ( + df.filter(pl.col(Column.TEAM_ID) == Constant.BALL) + .select([Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID]) + .unique() + ) + + # Find frames missing ball rows + frames_missing_ball = all_frames.join( + frames_with_ball, + on=[Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID], + how="anti", + ) + + # Create a dataframe of ball rows to add with appropriate columns + if frames_missing_ball.height > 0: + # Create base rows for missing balls + ball_rows_to_add = frames_missing_ball.with_columns( + [ + pl.lit(Constant.BALL).alias(Column.TEAM_ID), + pl.lit(Constant.BALL).alias(Column.POSITION_NAME), + ] + ) + + # Add to padding rows using same pattern as for players + for row in ball_rows_to_add.iter_rows(named=True): + base_row = { + col: row[col] + for col in keep_columns + + group_by_columns + + 
[Column.POSITION_NAME] + + self.global_feature_cols + if col in row + } + padding_rows.append(base_row) + + if len(padding_rows) == 0: + return df + + padding_df = pl.DataFrame(padding_rows) + + schema = df.schema + + padding_df = padding_df.with_columns( + [create_default_expression(col, schema[col]) for col in empty_columns] + + [ + pl.lit(None).cast(schema[col]).alias(col) + for col in user_defined_columns + ] + ) + padding_df = padding_df.with_columns( + [pl.col(col).cast(df.schema[col]).alias(col) for col in group_by_columns] + ) + + padding_df = padding_df.join( + ( + df.unique(group_by_columns).select( + group_by_columns + self.global_feature_cols + ) + ), + on=group_by_columns, + how="left", + ) + + padding_df = padding_df.with_columns( + [ + pl.col(col_name).cast(df.schema[col_name]).alias(col_name) + for col_name in df.columns + ] + ).select(df.columns) + + result = pl.concat([df, padding_df], how="vertical") + + total_frames = result.select(Group.BY_FRAME).unique().height + + frame_completeness = ( + result.group_by(Group.BY_FRAME) + .agg( + [ + (pl.col(Column.TEAM_ID).eq(Constant.BALL).sum() == 1).alias( + "has_ball" + ), + ( + pl.col(Column.TEAM_ID) + .eq(pl.col(Column.BALL_OWNING_TEAM_ID)) + .sum() + == 11 + ).alias("has_owning_team"), + ( + ( + ~pl.col(Column.TEAM_ID).eq(Constant.BALL) + & ~pl.col(Column.TEAM_ID).eq( + pl.col(Column.BALL_OWNING_TEAM_ID) + ) + ).sum() + == 11 + ).alias("has_other_team"), + ] + ) + .filter( + pl.col("has_ball") + & pl.col("has_owning_team") + & pl.col("has_other_team") + ) + ) + + complete_frames = frame_completeness.height + + dropped_frames = total_frames - complete_frames + if dropped_frames > 0 and self.verbose: + self.__warn_dropped_frames(dropped_frames, total_frames) + + return result.join(frame_completeness, on=Group.BY_FRAME, how="inner") + + @staticmethod + def __warn_dropped_frames(dropped_frames, total_frames): + import warnings + + warnings.warn( + f"""Setting pad=True drops frames that do not have at 
least 1 object for the attacking team, defending team or ball. + This operation dropped {dropped_frames} incomplete frames out of {total_frames} total frames ({(dropped_frames/total_frames)*100:.2f}%) + """ + ) + + def _apply_graph_settings(self): + return GraphSettingsPolars( + home_team_id=str(self._kloppy_settings.home_team_id), + away_team_id=str(self._kloppy_settings.away_team_id), + orientation=self._kloppy_settings.orientation, + pitch_dimensions=self.pitch_dimensions, + max_player_speed=self.settings.max_player_speed, + max_ball_speed=self.settings.max_ball_speed, + max_player_acceleration=self.settings.max_player_acceleration, + max_ball_acceleration=self.settings.max_ball_acceleration, self_loop_ball=self.self_loop_ball, adjacency_matrix_connect_type=self.adjacency_matrix_connect_type, adjacency_matrix_type=self.adjacency_matrix_type, label_type=self.label_type, - infer_ball_ownership=self.infer_ball_ownership, - infer_goalkeepers=self.infer_goalkeepers, defending_team_node_value=self.defending_team_node_value, non_potential_receiver_node_value=self.non_potential_receiver_node_value, random_seed=self.random_seed, @@ -120,178 +412,302 @@ def __post_init__(self): verbose=self.verbose, ) - if isinstance(self.dataset, TrackingDataset): - if not self.dataset.metadata.flags & DatasetFlag.BALL_OWNING_TEAM: - to_orientation = Orientation.STATIC_HOME_AWAY - else: - to_orientation = Orientation.BALL_OWNING_TEAM - - self.dataset = DatasetTransformer.transform_dataset( - dataset=self.dataset, - to_orientation=to_orientation, - to_coordinate_system=SecondSpectrumCoordinateSystem( - pitch_length=self.dataset.metadata.pitch_dimensions.pitch_length, - pitch_width=self.dataset.metadata.pitch_dimensions.pitch_width, - ), - ) - self.orientation = self.dataset.metadata.orientation + def _sport_specific_checks(self): + if not isinstance(self.label_column, str): + raise Exception("'label_col' should be of type string (str)") - self.settings.pitch_dimensions = 
self.dataset.metadata.pitch_dimensions + if not isinstance(self.graph_id_column, str): + raise Exception("'graph_id_col' should be of type string (str)") - def _sport_specific_checks(self): - if not self.labels and not self.prediction: + if not isinstance(self.chunk_size, int): + raise Exception("chunk_size should be of type integer (int)") + + if not self.label_column in self.dataset.columns and not self.prediction: raise Exception( - "Please specify 'labels' or set 'prediction=True' if you want to use the converted dataset to make predictions on." + "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on." ) - if self.graph_id is not None and self.graph_ids: - raise Exception("Please set either 'graph_id' or 'graph_ids', not both...") + if not self.label_column in self.dataset.columns and self.prediction: + self.dataset = self.dataset.with_columns( + pl.lit(None).alias(self.label_column) + ) - if self.graph_ids: - if not self.prediction: - if not set(list(self.labels.keys())) == set( - list(self.graph_ids.keys()) - ): - raise KeyMismatchException( - "When 'graph_id' is of type dict it needs to have the exact same keys as 'labels'..." - ) - if not self.graph_ids and self.prediction: - self.graph_ids = {x.frame_id: x.frame_id for x in self.dataset} + if not self.graph_id_column in self.dataset.columns: + raise Exception( + "Please specify a 'graph_id_col' and add that column to your 'dataset' ..." 
+ ) - if self.labels and not isinstance(self.labels, dict): - raise Exception("'labels' should be of type dictionary (dict)") + if self.non_potential_receiver_node_value and not isinstance( + self.non_potential_receiver_node_value, float + ): + raise Exception( + "'non_potential_receiver_node_value' should be of type float" + ) - if self.graph_id and not isinstance(self.graph_id, (str, int, dict)): - raise Exception("'graph_id_col' should be of type {str, int, dict}") + @property + def _exprs_variables(self): + exprs_variables = [ + Column.X, + Column.Y, + Column.Z, + Column.SPEED, + Column.VX, + Column.VY, + Column.VZ, + Column.ACCELERATION, + Column.AX, + Column.AY, + Column.AZ, + Column.TEAM_ID, + Column.POSITION_NAME, + Column.BALL_OWNING_TEAM_ID, + Column.IS_BALL_CARRIER, + self.graph_id_column, + self.label_column, + ] + exprs = ( + exprs_variables + self.global_feature_cols + self.additional_feature_cols + ) + return exprs - if self.graph_ids and not isinstance(self.graph_ids, dict): - raise Exception("chunk_size should be of type dictionary (dict") + @property + def default_node_feature_funcs(self) -> list: + return [ + x_normed, + y_normed, + speeds_normed, + velocity_components_2d_normed, + distance_to_goal_normed, + distance_to_ball_normed, + is_possession_team, + is_gk, + is_ball, + angle_to_goal_components_2d_normed, + angle_to_ball_components_2d_normed, + is_ball_carrier, + ] - if not isinstance(self.infer_goalkeepers, bool): - raise Exception("'infer_goalkeepers' should be of type boolean (bool)") + @property + def default_edge_feature_funcs(self) -> list: + return [ + distances_between_players_normed, + speed_difference_normed, + angle_between_players_normed, + velocity_difference_normed, + ] - if not isinstance(self.infer_ball_ownership, bool): - raise Exception("'infer_ball_ownership' should be of type boolean (bool)") + def __add_additional_kwargs(self, d): + d["ball_id"] = Constant.BALL + d["possession_team_id"] = 
d[Column.BALL_OWNING_TEAM_ID][0] + d["is_gk"] = np.where( + d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False + ) + d["position"] = np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1) + d["velocity"] = np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1) - if self.boundary_correction and not isinstance(self.boundary_correction, float): - raise Exception("'boundary_correction' should be of type float") + if len(np.where(d["team_id"] == d["ball_id"])[0]) >= 1: + ball_index = np.where(d["team_id"] == d["ball_id"])[0] + ball_position = d["position"][ball_index][0] + else: + ball_position = np.asarray([0.0, 0.0, 0.0]) + ball_index = 0 - if not isinstance(self.max_player_speed, (float, int)): - raise Exception("'max_player_speed' should be of type float or int") + ball_carriers = np.where(d[Column.IS_BALL_CARRIER] == True)[0] + if len(ball_carriers) == 0: + ball_carrier_idx = None + else: + ball_carrier_idx = ball_carriers[0] - if not isinstance(self.max_ball_speed, (float, int)): - raise Exception("'max_ball_speed' should be of type float or int") + d["ball_position"] = ball_position - # if not isinstance(self.max_player_acceleration, (float, int)): - # raise Exception("'max_player_acceleration' should be of type float or int") + d["ball_idx"] = ball_index + d["ball_carrier_idx"] = ball_carrier_idx + return d - # if not isinstance(self.max_ball_acceleration, (float, int)): - # raise Exception("'max_ball_acceleration' should be of type float or int") + def _compute(self, args: List[pl.Series]) -> dict: + frame_data: dict = { + col: args[i].to_numpy() for i, col in enumerate(self._exprs_variables) + } + frame_data = self.__add_additional_kwargs(frame_data) - if self.ball_carrier_treshold and not isinstance( - self.ball_carrier_treshold, float + if not np.all( + frame_data[self.graph_id_column] == frame_data[self.graph_id_column][0] ): - raise Exception("'ball_carrier_treshold' should be of type float") + raise ValueError( + "graph_id 
selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..." + ) - if self.non_potential_receiver_node_value and not isinstance( - self.non_potential_receiver_node_value, float + if not self.prediction and not np.all( + frame_data[self.label_column] == frame_data[self.label_column][0] ): - raise Exception( - "'non_potential_receiver_node_value' should be of type float" + raise ValueError( + """Label selection contains multiple different values for a single selection (group by) of game_id and frame_id, + make sure this is not the case. Each group can only have 1 label.""" ) - def _convert(self, frame: Frame): - data = DefaultTrackingModel( - frame, - fps=self.dataset.metadata.frame_rate, - infer_ball_ownership=self.settings.infer_ball_ownership, - infer_goalkeepers=self.settings.infer_goalkeepers, - ball_carrier_treshold=self.settings.ball_carrier_treshold, - orientation=self.orientation, - verbose=self.settings.verbose, - pad_n_players=( - None if not self.settings.pad else self.settings.pad_settings.n_players - ), + adjacency_matrix = compute_adjacency_matrix( + settings=self.settings, **frame_data + ) + edge_features, self._edge_feature_dims = compute_edge_features( + adjacency_matrix=adjacency_matrix, + funcs=self.edge_feature_funcs, + opts=self.feature_opts, + settings=self.settings, + **frame_data, ) - if isinstance(frame, Frame): - if not self.prediction: - label = self.labels.get(frame.frame_id, None) - else: - label = -1 - - graph_id = None - if ( - self.graph_id is None and not self.graph_ids - ): # technically graph_id can be 0 - graph_id = None - elif self.graph_ids: - graph_id = self.graph_ids.get(frame.frame_id, None) - elif self.graph_id: - graph_id = self.graph_id - else: - raise NotImplementedError() + node_features, self._node_feature_dims = compute_node_features( + funcs=self.node_feature_funcs, + opts=self.feature_opts, + settings=self.settings, + **frame_data, + ) - if not self.prediction 
and label is None: - if self.settings.verbose: - warn( - f"""No label for frame={frame.frame_id} in 'labels'...""", - NoLabelWarning, - ) - frame_id = frame.frame_id - else: - raise NotImplementedError( - """Format is not supported, should be TrackingDataset (Kloppy)""" - ) + if self.global_feature_cols: + failed = [ + col + for col in self.global_feature_cols + if not np.all(frame_data[col] == frame_data[col][0]) + ] + if failed: + raise ValueError( + f"""graph_feature_cols contains multiple different values for a group in the groupby ({Group.BY_FRAME}) selection for the columns {failed}. Make sure each group has the same values per individual column.""" + ) - return data, label, frame_id, graph_id + global_features = ( + np.asarray([frame_data[col] for col in self.global_feature_cols]).T[0] + if self.global_feature_cols + else None + ) + for col in self.global_feature_cols: + self._node_feature_dims[col] = 1 - def to_graph_frames(self) -> dict: - if not self.graph_frames: - from tqdm import tqdm + node_features = add_global_features( + node_features=node_features, + global_features=global_features, + global_feature_type=self.global_feature_type, + **frame_data, + ) + return { + "e": pl.Series( + [edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) + ), + "x": pl.Series( + [node_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) + ), + "a": pl.Series( + [adjacency_matrix.tolist()], dtype=pl.List(pl.List(pl.Int32)) + ), + "e_shape_0": edge_features.shape[0], + "e_shape_1": edge_features.shape[1], + "x_shape_0": node_features.shape[0], + "x_shape_1": node_features.shape[1], + "a_shape_0": adjacency_matrix.shape[0], + "a_shape_1": adjacency_matrix.shape[1], + self.graph_id_column: frame_data[self.graph_id_column][0], + self.label_column: frame_data[self.label_column][0], + } - if not self.dataset: - raise MissingDatasetError( - "Please specificy a 'dataset' a Kloppy TrackingDataset (see README)" - ) + @property + def return_dtypes(self): + return 
pl.Struct( + { + "e": pl.List(pl.List(pl.Float64)), + "x": pl.List(pl.List(pl.Float64)), + "a": pl.List(pl.List(pl.Float64)), + "e_shape_0": pl.Int64, + "e_shape_1": pl.Int64, + "x_shape_0": pl.Int64, + "x_shape_1": pl.Int64, + "a_shape_0": pl.Int64, + "a_shape_1": pl.Int64, + self.graph_id_column: pl.String, + self.label_column: pl.Int64, + } + ) - if isinstance(self.dataset, TrackingDataset): - if not self.labels and not self.prediction: - raise MissingLabelsError( - "Please specificy 'labels' of type Dict when using Kloppy" - ) - else: - raise IncorrectDatasetTypeError( - "dataset should be of type TrackingDataset" - ) + def _convert(self): + # Group and aggregate in one step + return ( + self.dataset.group_by(Group.BY_FRAME, maintain_order=True) + .agg( + pl.map_groups( + exprs=self._exprs_variables, + function=self._compute, + return_dtype=self.return_dtypes, + ).alias("result_dict") + ) + .with_columns( + [ + *[ + pl.col("result_dict").struct.field(f).alias(f) + for f in [ + "a", + "e", + "x", + self.graph_id_column, + self.label_column, + ] + ], + *[ + pl.col("result_dict") + .struct.field(f"{m}_shape_{i}") + .alias(f"{m}_shape_{i}") + for m in ["a", "e", "x"] + for i in [0, 1] + ], + ] + ) + .drop("result_dict") + ) - self.graph_frames = list() - - for frame in tqdm(self.dataset, desc="Processing frames"): - data, label, frame_id, graph_id = self._convert(frame) - if data.home_players and data.away_players: - try: - gnn_frame = GraphFrame( - frame_id=frame_id, - data=data, - label=label, - graph_id=graph_id, - settings=self.settings, + def to_graph_frames(self) -> List[dict]: + def process_chunk(chunk: pl.DataFrame) -> List[dict]: + return [ + { + "a": make_sparse( + reshape_from_size( + chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] ) - if gnn_frame.graph_data: - self.graph_frames.append(gnn_frame) - except QhullError: - pass + ), + "x": reshape_from_size( + chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] + ), + "e": 
reshape_from_size( + chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] + ), + "y": np.asarray([chunk[self.label_column][i]]), + "id": chunk[self.graph_id_column][i], + } + for i in range(len(chunk)) + ] + graph_df = self._convert() + self.graph_frames = [ + graph + for chunk in graph_df.lazy() + .collect(engine="gpu") + .iter_slices(self.chunk_size) + for graph in process_chunk(chunk) + ] return self.graph_frames def to_spektral_graphs(self) -> List[Graph]: if not self.graph_frames: self.to_graph_frames() - return [g.to_spektral_graph() for g in self.graph_frames] + return [ + Graph( + x=d["x"], + a=d["a"], + e=d["e"], + y=d["y"], + id=d["id"], + ) + for d in self.graph_frames + ] - def to_pickle(self, file_path: str) -> None: + def to_pickle(self, file_path: str, verbose: bool = False) -> None: """ We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data) @@ -304,6 +720,9 @@ def to_pickle(self, file_path: str) -> None: if not self.graph_frames: self.to_graph_frames() + if verbose: + print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...") + import pickle import gzip from pathlib import Path @@ -314,5 +733,418 @@ def to_pickle(self, file_path: str) -> None: directories.mkdir(parents=True, exist_ok=True) with gzip.open(file_path, "wb") as file: - data = [x.graph_data for x in self.graph_frames] - pickle.dump(data, file) + pickle.dump(self.graph_frames, file) + + def plot( + self, + file_path: str, + fps: int = None, + timestamp: pl.duration = None, + end_timestamp: pl.duration = None, + period_id: int = None, + team_color_a: str = "#CD0E61", + team_color_b: str = "#0066CC", + ball_color: str = "black", + sort: bool = True, + color_by: Literal["ball_owning", "static_home_away"] = "ball_owning", + ): + """ + Plot tracking data as a static image or video file. 
+ + This method visualizes tracking data for players and the ball. It can generate either: + - A single PNG image (if either fps or end_timestamp is None, or both are None) + - An MP4 video (if both fps and end_timestamp are provided) + + Parameters + ---------- + file_path : str + The output path where the PNG or MP4 file will be saved + fps : int, optional + Frames per second for video output. If None, a static image is generated + timestamp : pl.duration, optional + The starting timestamp to plot. If None, starts from the beginning of available data + end_timestamp : pl.duration, optional + The ending timestamp for video output. If None, a static image is generated + period_id : int, optional + ID of the match period to visualize. If None, all periods are included + team_color_a : str, default "#CD0E61" + Hex color code for Team A visualization + team_color_b : str, default "#0066CC" + Hex color code for Team B visualization + ball_color : str, default "black" + Color for ball visualization + color_by : Literal["ball_owning", "static_home_away"], default "ball_owning" + Method for coloring the teams: + - "ball_owning": Colors teams based on ball possession + - "static_home_away": Uses static colors for home and away teams + + Returns + ------- + None + The function saves the output file to the specified file_path but doesn't return any value + + Notes + ----- + Output file type is determined by parameters: + - PNG: Generated when either fps or end_timestamp is None, or both are None + - MP4: Generated when both fps and end_timestamp are provided + + Raises + ------ + ValueError + If file extension doesn't match the parameters provided (e.g., .mp4 extension + but missing fps or end_timestamp, or .png extension with both fps and end_timestamp) + """ + try: + import matplotlib.animation as animation + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + except ImportError: + raise ImportError( + "Seems like you don't have matplotlib 
installed. Please" + " install it using: pip install matplotlib" + ) + + if (fps is None and end_timestamp is not None) or ( + fps is not None and end_timestamp is None + ): + raise ValueError( + "Both 'fps' and 'end_timestamp' must be provided together to generate a video. " + ) + + # Determine the output type based on parameters + generate_video = fps is not None and end_timestamp is not None + + # Get file extension if it exists + path = pathlib.Path(file_path) + file_extension = path.suffix.lower() if path.suffix else "" + + # If no extension, add the appropriate one based on parameters + if not file_extension: + suffix = ".mp4" if generate_video else ".png" + file_path = str(path.with_suffix(suffix)) + + # Otherwise, validate that the extension matches the parameters + else: + if generate_video and file_extension != ".mp4": + raise ValueError( + f"Parameters fps and end_timestamp indicate video output, " + f"but file extension is '{file_extension}'. Use '.mp4' extension for video output." + ) + elif not generate_video and file_extension == ".mp4": + raise ValueError( + "To generate an MP4 video, both 'fps' and 'end_timestamp' must be provided. " + "For static image output, use a '.png' extension." + ) + elif not generate_video and file_extension != ".png": + raise ValueError( + f"For static image output, use '.png' extension instead of '{file_extension}'." 
+ ) + + self._team_color_a = team_color_a + self._team_color_b = team_color_b + self._ball_color = ball_color + self._color_by = color_by + + if period_id is not None and not isinstance(period_id, int): + raise TypeError("period_id should be of type integer") + + if all(x is None for x in [timestamp, end_timestamp, period_id]): + # No filters specified, use the entire dataset + df = self.dataset + elif timestamp is not None and period_id is not None: + if end_timestamp is not None: + # Both timestamp and end_timestamp provided - filter for a range + df = self.dataset.filter( + (pl.col(Column.TIMESTAMP).is_between(timestamp, end_timestamp)) + & (pl.col(Column.PERIOD_ID) == period_id) + ) + else: + # Only timestamp provided (no end_timestamp) - filter for specific timestamp + df = self.dataset.filter( + (pl.col(Column.TIMESTAMP) == timestamp) + & (pl.col(Column.PERIOD_ID) == period_id) + ) + # Handle the case where a single timestamp has multiple frame_ids + df = ( + df.with_columns( + pl.col(Column.FRAME_ID) + .rank(method="min") + .over(Column.TIMESTAMP) + .alias("frame_rank") + ) + # Keep only rows where the frame has rank = 1 (first frame for each timestamp) + .filter(pl.col("frame_rank") == 1).drop("frame_rank") + ) + else: + raise ValueError( + "Please specify both timestamp and period_id, or specify all of timestamp, end_timestamp, and period_id, or none of them." 
+ ) + + if df.is_empty(): + raise ValueError("Selection is empty, please try different timestamp(s)") + + def plot_graph(): + import matplotlib.pyplot as plt + + # Plot node features in top-left + ax1 = self._fig.add_subplot(self._gs[0, 0]) + ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd") + ax1.set_xlabel(f"Node Features {self._graph.x.shape}") + + # Set y labels to integers + num_rows = self._graph.x.shape[0] + ax1.set_yticks(range(num_rows)) + ax1.set_yticklabels([str(i) for i in range(num_rows)]) + + node_feature_yticklabels = feature_ticklabels(self._node_feature_dims) + ax1.xaxis.set_ticks_position("top") + ax1.set_xticks(range(len(node_feature_yticklabels))) + ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left") + + # Plot ajacency matrix in bottom-left + ax2 = self._fig.add_subplot(self._gs[1, 0]) + ax2.imshow(self._graph.a.toarray(), aspect="auto", cmap="YlOrRd") + ax2.set_xlabel(f"Adjacency Matrix {self._graph.a.shape}") + + # Set both x and y labels to integers + num_rows_a = self._graph.a.toarray().shape[0] + num_cols_a = self._graph.a.toarray().shape[1] + + ax2.set_yticks(range(num_rows_a)) + ax2.set_yticklabels([str(i) for i in range(num_rows_a)]) + ax2.xaxis.set_ticks_position("top") + ax2.set_xticks(range(num_cols_a)) + ax2.set_xticklabels([str(i) for i in range(num_cols_a)]) + + # Plot Edge Features on the right (spanning both rows) + ax3 = self._fig.add_subplot(self._gs[:, 1]) + + _, size_a = non_zeros(self._graph.a.toarray()[0 : self._ball_carrier_idx]) + ball_carrier_edge_idx, num_rows_e = non_zeros( + np.asarray( + [list(x) for x in self._graph.a.toarray()][self._ball_carrier_idx] + ) + ) + + im3 = ax3.imshow( + self._graph.e[size_a : num_rows_e + size_a, :], + aspect="auto", + cmap="YlOrRd", + ) + + ax3.set_yticks(range(num_rows_e)) + ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18) + ax3.set_xlabel(f"Edge Features {self._graph.e.shape}") + + labels = ax3.get_yticklabels() + if self._ball_carrier_idx in 
ball_carrier_edge_idx[0]: + idx_position = list(ball_carrier_edge_idx[0]).index( + self._ball_carrier_idx + ) + # Modify just that specific label + labels[idx_position].set_color(self._ball_carrier_color) + labels[idx_position].set_fontweight("bold") + # Set the modified labels back + ax3.set_yticklabels(labels) + + # Set x labels to edge function names at the top, rotated 45 degrees + edge_feature_xticklabels = feature_ticklabels(self._edge_feature_dims) + ax3.xaxis.set_ticks_position("top") + ax3.set_xticks(range(len(edge_feature_xticklabels))) + ax3.set_xticklabels(edge_feature_xticklabels, rotation=45, ha="left") + + plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2) + + def plot_vertical_pitch(frame_data: pl.DataFrame): + try: + from mplsoccer import VerticalPitch + except ImportError: + raise ImportError( + "Seems like you don't have mplsoccer installed. Please" + " install it using: pip install mplsoccer" + ) + + ax4 = self._fig.add_subplot(self._gs[:, 2]) + pitch = VerticalPitch( + pitch_type="secondspectrum", + pitch_length=self.pitch_dimensions.pitch_length, + pitch_width=self.pitch_dimensions.pitch_width, + pitch_color="#ffffff", + pad_top=-0.05, + ) + pitch.draw(ax=ax4) + player_and_ball(frame_data=frame_data, ax=ax4) + direction_of_play_arrow(ax=ax4) + + def feature_ticklabels(feature_dims): + _feature_ticklabels = [] + for key, value in feature_dims.items(): + if value == 1: + _feature_ticklabels.append(key) + else: + _feature_ticklabels.extend([key] + [None] * (value - 1)) + return _feature_ticklabels + + def direction_of_play_arrow(ax): + arrow_x = -30 + arrow_y = -7.5 + arrow_dx = 0 + arrow_dy = 15 + + if self.settings.orientation == Orientation.STATIC_HOME_AWAY: + if self._ball_owning_team_id != self.settings.home_team_id: + arrow_y = arrow_y * -1 + arrow_dy = arrow_dy * -1 + elif self.settings.orientation == Orientation.BALL_OWNING_TEAM: + pass + else: + raise ValueError(f"Unsupported orientation {self.settings.orientation}") + + # Create the arrow 
to indicate direction of play + ax.arrow( + arrow_x, + arrow_y, + arrow_dx, + arrow_dy, + head_width=3, + head_length=2, + fc="#c2c2c2", + ec="#c2c2c2", + width=0.5, + length_includes_head=True, + zorder=1, + ) + + def player_and_ball(frame_data, ax): + if self._color_by == "ball_owning": + team_id = self._ball_owning_team_id + elif self._color_by == "static_home_away": + team_id = self.settings.home_team_id + else: + raise ValueError(f"Unsupported color_by {self._color_by}") + + self._ball_carrier_color = None + + for i, r in enumerate(frame_data.iter_rows(named=True)): + v, vy, vx, y, x = ( + r[Column.SPEED], + r[Column.VX], + r[Column.VY], + r[Column.X], + r[Column.Y], + ) + is_ball = True if r[Column.TEAM_ID] == self.settings.ball_id else False + + if not is_ball: + if team_id is None: + team_id = r[Column.TEAM_ID] + + color = ( + self._team_color_a + if r[Column.TEAM_ID] == team_id + else self._team_color_b + ) + + if r[Column.IS_BALL_CARRIER] == True: + self._ball_carrier_color = color + + ax.scatter(x, y, color=color, s=450) + + if v > 1.0: + ax.annotate( + "", + xy=(x + vx, y + vy), + xytext=(x, y), + arrowprops=dict(arrowstyle="->", color=color, lw=3), + ) + + else: + ax.scatter(x, y, color=self._ball_color, s=250, zorder=10) + # # Text with white border + text = ax.text( + x + (-1.2 if is_ball else 0.0), + y + (-1.2 if is_ball else 0.0), + i, + color=self._ball_color if is_ball else color, + fontsize=12, + ha="center", + va="center", + zorder=15 if is_ball else 5, + ) + + import matplotlib.patheffects as path_effects + + text.set_path_effects( + [ + path_effects.Stroke(linewidth=6, foreground="white"), + path_effects.Normal(), + ] + ) + ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22) + + def frame_plot(self, frame_data): + self._gs = GridSpec( + 2, + 3, + width_ratios=[2, 1, 3], + height_ratios=[1, 1], + wspace=0.1, + hspace=0.06, + left=0.05, + right=1.0, + bottom=0.05, + ) + + # Process the current frame + features = 
self._compute([frame_data[col] for col in self._exprs_variables]) + a = make_sparse( + reshape_from_size( + features["a"], features["a_shape_0"], features["a_shape_1"] + ) + ) + x = reshape_from_size( + features["x"], features["x_shape_0"], features["x_shape_1"] + ) + e = reshape_from_size( + features["e"], features["e_shape_0"], features["e_shape_1"] + ) + y = np.asarray([features[self.label_column]]) + + self._graph = Graph( + a=a, + x=x, + e=e, + y=y, + ) + + self._ball_carrier_idx = np.where( + frame_data[Column.IS_BALL_CARRIER] == True + )[0][0] + self._ball_owning_team_id = list(frame_data[Column.BALL_OWNING_TEAM_ID])[0] + + plot_vertical_pitch(frame_data) + plot_graph() + + plt.tight_layout() + + self._fig = plt.figure(figsize=(25, 18)) + self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05) + + if sort: + df = self._sort(df) + + if generate_video: + writer = animation.FFMpegWriter(fps=fps, bitrate=1800) + + with writer.saving(self._fig, file_path, dpi=300): + for group_id, frame_data in df.group_by( + Group.BY_FRAME, maintain_order=True + ): + self._fig.clear() + frame_plot(self, frame_data) + writer.grab_frame() + + else: + frame_plot(self, frame_data=df) + plt.savefig(file_path, dpi=300) diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py deleted file mode 100644 index 3304b4b8..00000000 --- a/unravel/soccer/graphs/graph_converter_pl.py +++ /dev/null @@ -1,1162 +0,0 @@ -import logging -import sys - -from dataclasses import dataclass - -from typing import List, Union, Dict, Literal, Any, Optional, Callable - -import inspect - -import pathlib - -from kloppy.domain import MetricPitchDimensions, Orientation - -from spektral.data import Graph - -from .graph_settings_pl import GraphSettingsPolars -from ..dataset.kloppy_polars import KloppyPolarsDataset, Column, Group, Constant -from .features import ( - compute_node_features, - add_global_features, - compute_adjacency_matrix, - compute_edge_features, -) 
- -from ...utils import * - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) -stdout_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(stdout_handler) - - -@dataclass(repr=True) -class SoccerGraphConverterPolars(DefaultGraphConverter): - """ - Converts our dataset TrackingDataset into an internal structure - - Attributes: - dataset (KloppyPolarsDataset): KloppyPolarsDataset created from a Kloppy dataset. - chunk_size (int): Determines how many Graphs get processed simultanously. - non_potential_receiver_node_value (float): Value between 0 and 1 to assign to the defing team players - global_feature_cols (list[str]): List of columns in the dataset that are Graph level features (e.g. team strength rating, win probabilities etc) - we want to add to our model. A list of column names corresponding to the Polars dataframe within KloppyPolarsDataset.data - that are graph level features. They should be joined to the KloppyPolarsDataset.data dataframe such that - each Group in the group_by has the same value per column. We take the first value of the group, and assign this as a - "graph level feature" to the ball node. - global_feature_type: A literal of type "ball" or "all". When set to "ball" the global features will be assigned to only the ball node, if set to "all" - the they will be assigned to every player and ball in the node features. - edge_feature_funcs: A list of functions (decorated with @graph_feature(is_custom, feature_type="edge")) - that take **kwargs as input and return a numpy array (dimensions should match expected (N,N) shape or tuple with multipe (N, N) numpy arrays). 
- node_feature_funcs: A list of functions (decorated with @graph_feature(is_custom, feature_type="node")) - that take **kwargs as input and return a numpy array (dimensions should match expected (N,) shape or (N, k) ) - additional_feature_cols: Column the user has added to the 'KloppyPolarsDataset.data' that are not to be added as global features, - but can now be accessed by edge_feature_funcs and node_feature_funcs through kwargs. - (e.g. if the user adds "height" for each player, as a column to the 'KloppyPolarsDataset.data' and - they want to use it to compute the height difference between all players as an edge feature they would - pass additional_feature_cols=["height"] and their custom edge feature function can now access kwargs['height']) - """ - - dataset: KloppyPolarsDataset = None - - chunk_size: int = 2_0000 - non_potential_receiver_node_value: float = 0.1 - - edge_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field( - repr=False, default_factory=list - ) - node_feature_funcs: List[Callable[[Dict[str, Any]], np.ndarray]] = field( - repr=False, default_factory=list - ) - - global_feature_cols: Optional[List[str]] = field(repr=False, default_factory=list) - global_feature_type: Literal["ball", "all"] = "ball" - - additional_feature_cols: Optional[List[str]] = field( - repr=False, default_factory=list - ) - - _edge_feature_dims: Dict[str, int] = field( - repr=False, default_factory=dict, init=False - ) - _node_feature_dims: Dict[str, int] = field( - repr=False, default_factory=dict, init=False - ) - - def __post_init__(self): - if not isinstance(self.dataset, KloppyPolarsDataset): - raise ValueError("dataset should be of type KloppyPolarsDataset...") - - self.pitch_dimensions: MetricPitchDimensions = ( - self.dataset.settings.pitch_dimensions - ) - self._kloppy_settings = self.dataset.settings - - self.label_column: str = ( - self.label_col if self.label_col is not None else self.dataset._label_column - ) - self.graph_id_column: str = ( - 
self.graph_id_col - if self.graph_id_col is not None - else self.dataset._graph_id_column - ) - - self.dataset = self.dataset.data - - if not self.edge_feature_funcs: - self.edge_feature_funcs = self.default_edge_feature_funcs - - self._verify_feature_funcs(self.edge_feature_funcs, feature_type="edge") - - if not self.node_feature_funcs: - self.node_feature_funcs = self.default_node_feature_funcs - - self._verify_feature_funcs(self.node_feature_funcs, feature_type="node") - - self._sport_specific_checks() - self.settings = self._apply_graph_settings() - - if self.pad: - self.dataset = self._apply_padding() - else: - self.dataset = self._remove_incomplete_frames() - - self._sample() - self._shuffle() - - def _sample(self): - if self.sample_rate is None: - return - else: - self.dataset = self.dataset.filter( - pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0 - ) - - def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): - for i, func in enumerate(funcs): - # Check if it has the attributes added by the decorator - if not hasattr(func, "feature_type"): - func_str = inspect.getsource(func).strip() - raise Exception( - f"Error processing feature function:\n" - f"{func.__name__} defined as:\n" - f"{func_str}\n\n" - "Function is missing the @graph_feature decorator. " - ) - - if func.feature_type != feature_type: - func_str = inspect.getsource(func).strip() - raise Exception( - f"Error processing feature function:\n" - f"{func.__name__} defined as:\n" - f"{func_str}\n\n" - "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. 
" - ) - - @staticmethod - def _sort(df): - sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( - (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) - & (pl.col(Column.TEAM_ID) != Constant.BALL) - ).cast(int) - - df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]) - return df - - def _shuffle(self): - if self.settings.random_seed is None or self.settings.random_seed == False: - self.dataset = self._sort(self.dataset) - if isinstance(self.settings.random_seed, int): - self.dataset = self.dataset.sample( - fraction=1.0, seed=self.settings.random_seed - ) - elif self.settings.random_seed == True: - self.dataset = self.dataset.sample(fraction=1.0) - else: - self.dataset = self._sort(self.dataset) - - def _remove_incomplete_frames(self) -> pl.DataFrame: - df = self.dataset - total_frames = len(df.unique(Group.BY_FRAME)) - - valid_frames = ( - df.group_by(Group.BY_FRAME) - .agg(pl.col(Column.TEAM_ID).n_unique().alias("unique_teams")) - .filter(pl.col("unique_teams") == 3) - .select(Group.BY_FRAME) - ) - dropped_frames = total_frames - len(valid_frames.unique(Group.BY_FRAME)) - if dropped_frames > 0 and self.verbose: - self.__warn_dropped_frames(dropped_frames, total_frames) - - return df.join(valid_frames, on=Group.BY_FRAME) - - def _apply_padding(self) -> pl.DataFrame: - df = self.dataset - - keep_columns = [ - Column.TIMESTAMP, - Column.BALL_STATE, - self.label_column, - self.graph_id_column, - ] - empty_columns = [ - Column.POSITION_NAME, - Column.OBJECT_ID, - Column.IS_BALL_CARRIER, - Column.X, - Column.Y, - Column.Z, - Column.VX, - Column.VY, - Column.VZ, - Column.SPEED, - Column.AX, - Column.AY, - Column.AZ, - Column.ACCELERATION, - ] - group_by_columns = [ - Column.GAME_ID, - Column.PERIOD_ID, - Column.FRAME_ID, - Column.TEAM_ID, - Column.BALL_OWNING_TEAM_ID, - ] - - user_defined_columns = [ - x - for x in df.columns - if x - not in keep_columns - + group_by_columns - + empty_columns - + self.global_feature_cols - ] 
- - counts = df.group_by(group_by_columns).agg( - pl.len().alias("count"), - *[ - pl.first(col).alias(col) - for col in keep_columns + self.global_feature_cols - ], - ) - - counts = counts.with_columns( - [ - pl.when(pl.col(Column.TEAM_ID) == Constant.BALL) - .then(1) - .when(pl.col(Column.TEAM_ID) == pl.col(Column.BALL_OWNING_TEAM_ID)) - .then(11) - .otherwise(11) - .alias("target_length") - ] - ) - - groups_to_pad = counts.filter( - pl.col("count") < pl.col("target_length") - ).with_columns((pl.col("target_length") - pl.col("count")).alias("repeats")) - - padding_rows = [] - # This is where we pad players (missing balls get skipped because of 'target_length') - for row in groups_to_pad.iter_rows(named=True): - base_row = { - col: row[col] - for col in keep_columns + group_by_columns + self.global_feature_cols - } - padding_rows.extend([base_row] * row["repeats"]) - - # Now check if there are frames without ball rows - # Get all unique frames - all_frames = df.select( - [ - Column.GAME_ID, - Column.PERIOD_ID, - Column.FRAME_ID, - Column.BALL_OWNING_TEAM_ID, - ] - + keep_columns - + self.global_feature_cols - ).unique() - - # Get frames that have ball rows - frames_with_ball = ( - df.filter(pl.col(Column.TEAM_ID) == Constant.BALL) - .select([Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID]) - .unique() - ) - - # Find frames missing ball rows - frames_missing_ball = all_frames.join( - frames_with_ball, - on=[Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID], - how="anti", - ) - - # Create a dataframe of ball rows to add with appropriate columns - if frames_missing_ball.height > 0: - # Create base rows for missing balls - ball_rows_to_add = frames_missing_ball.with_columns( - [ - pl.lit(Constant.BALL).alias(Column.TEAM_ID), - pl.lit(Constant.BALL).alias(Column.POSITION_NAME), - ] - ) - - # Add to padding rows using same pattern as for players - for row in ball_rows_to_add.iter_rows(named=True): - base_row = { - col: row[col] - for col in keep_columns - + 
group_by_columns - + [Column.POSITION_NAME] - + self.global_feature_cols - if col in row - } - padding_rows.append(base_row) - - if len(padding_rows) == 0: - return df - - padding_df = pl.DataFrame(padding_rows) - - schema = df.schema - - padding_df = padding_df.with_columns( - [create_default_expression(col, schema[col]) for col in empty_columns] - + [ - pl.lit(None).cast(schema[col]).alias(col) - for col in user_defined_columns - ] - ) - padding_df = padding_df.with_columns( - [pl.col(col).cast(df.schema[col]).alias(col) for col in group_by_columns] - ) - - padding_df = padding_df.join( - ( - df.unique(group_by_columns).select( - group_by_columns + self.global_feature_cols - ) - ), - on=group_by_columns, - how="left", - ) - - padding_df = padding_df.with_columns( - [ - pl.col(col_name).cast(df.schema[col_name]).alias(col_name) - for col_name in df.columns - ] - ).select(df.columns) - - result = pl.concat([df, padding_df], how="vertical") - - total_frames = result.select(Group.BY_FRAME).unique().height - - frame_completeness = ( - result.group_by(Group.BY_FRAME) - .agg( - [ - (pl.col(Column.TEAM_ID).eq(Constant.BALL).sum() == 1).alias( - "has_ball" - ), - ( - pl.col(Column.TEAM_ID) - .eq(pl.col(Column.BALL_OWNING_TEAM_ID)) - .sum() - == 11 - ).alias("has_owning_team"), - ( - ( - ~pl.col(Column.TEAM_ID).eq(Constant.BALL) - & ~pl.col(Column.TEAM_ID).eq( - pl.col(Column.BALL_OWNING_TEAM_ID) - ) - ).sum() - == 11 - ).alias("has_other_team"), - ] - ) - .filter( - pl.col("has_ball") - & pl.col("has_owning_team") - & pl.col("has_other_team") - ) - ) - - complete_frames = frame_completeness.height - - dropped_frames = total_frames - complete_frames - if dropped_frames > 0 and self.verbose: - self.__warn_dropped_frames(dropped_frames, total_frames) - - return result.join(frame_completeness, on=Group.BY_FRAME, how="inner") - - @staticmethod - def __warn_dropped_frames(dropped_frames, total_frames): - import warnings - - warnings.warn( - f"""Setting pad=True drops frames 
that do not have at least 1 object for the attacking team, defending team or ball. - This operation dropped {dropped_frames} incomplete frames out of {total_frames} total frames ({(dropped_frames/total_frames)*100:.2f}%) - """ - ) - - def _apply_graph_settings(self): - return GraphSettingsPolars( - home_team_id=str(self._kloppy_settings.home_team_id), - away_team_id=str(self._kloppy_settings.away_team_id), - orientation=self._kloppy_settings.orientation, - pitch_dimensions=self.pitch_dimensions, - max_player_speed=self.settings.max_player_speed, - max_ball_speed=self.settings.max_ball_speed, - max_player_acceleration=self.settings.max_player_acceleration, - max_ball_acceleration=self.settings.max_ball_acceleration, - self_loop_ball=self.self_loop_ball, - adjacency_matrix_connect_type=self.adjacency_matrix_connect_type, - adjacency_matrix_type=self.adjacency_matrix_type, - label_type=self.label_type, - defending_team_node_value=self.defending_team_node_value, - non_potential_receiver_node_value=self.non_potential_receiver_node_value, - random_seed=self.random_seed, - pad=self.pad, - verbose=self.verbose, - ) - - def _sport_specific_checks(self): - if not isinstance(self.label_column, str): - raise Exception("'label_col' should be of type string (str)") - - if not isinstance(self.graph_id_column, str): - raise Exception("'graph_id_col' should be of type string (str)") - - if not isinstance(self.chunk_size, int): - raise Exception("chunk_size should be of type integer (int)") - - if not self.label_column in self.dataset.columns and not self.prediction: - raise Exception( - "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on." 
- ) - - if not self.label_column in self.dataset.columns and self.prediction: - self.dataset = self.dataset.with_columns( - pl.lit(None).alias(self.label_column) - ) - - if not self.graph_id_column in self.dataset.columns: - raise Exception( - "Please specify a 'graph_id_col' and add that column to your 'dataset' ..." - ) - - if self.non_potential_receiver_node_value and not isinstance( - self.non_potential_receiver_node_value, float - ): - raise Exception( - "'non_potential_receiver_node_value' should be of type float" - ) - - @property - def _exprs_variables(self): - exprs_variables = [ - Column.X, - Column.Y, - Column.Z, - Column.SPEED, - Column.VX, - Column.VY, - Column.VZ, - Column.ACCELERATION, - Column.AX, - Column.AY, - Column.AZ, - Column.TEAM_ID, - Column.POSITION_NAME, - Column.BALL_OWNING_TEAM_ID, - Column.IS_BALL_CARRIER, - self.graph_id_column, - self.label_column, - ] - exprs = ( - exprs_variables + self.global_feature_cols + self.additional_feature_cols - ) - return exprs - - @property - def default_node_feature_funcs(self) -> list: - return [ - x_normed, - y_normed, - speeds_normed, - velocity_components_2d_normed, - distance_to_goal_normed, - distance_to_ball_normed, - is_possession_team, - is_gk, - is_ball, - angle_to_goal_components_2d_normed, - angle_to_ball_components_2d_normed, - is_ball_carrier, - ] - - @property - def default_edge_feature_funcs(self) -> list: - return [ - distances_between_players_normed, - speed_difference_normed, - angle_between_players_normed, - velocity_difference_normed, - ] - - def __add_additional_kwargs(self, d): - d["ball_id"] = Constant.BALL - d["possession_team_id"] = d[Column.BALL_OWNING_TEAM_ID][0] - d["is_gk"] = np.where( - d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False - ) - d["position"] = np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1) - d["velocity"] = np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1) - - if len(np.where(d["team_id"] == d["ball_id"])[0]) >= 1: - 
ball_index = np.where(d["team_id"] == d["ball_id"])[0] - ball_position = d["position"][ball_index][0] - else: - ball_position = np.asarray([0.0, 0.0, 0.0]) - ball_index = 0 - - ball_carriers = np.where(d[Column.IS_BALL_CARRIER] == True)[0] - if len(ball_carriers) == 0: - ball_carrier_idx = None - else: - ball_carrier_idx = ball_carriers[0] - - d["ball_position"] = ball_position - - d["ball_idx"] = ball_index - d["ball_carrier_idx"] = ball_carrier_idx - return d - - def _compute(self, args: List[pl.Series]) -> dict: - frame_data: dict = { - col: args[i].to_numpy() for i, col in enumerate(self._exprs_variables) - } - frame_data = self.__add_additional_kwargs(frame_data) - - if not np.all( - frame_data[self.graph_id_column] == frame_data[self.graph_id_column][0] - ): - raise ValueError( - "graph_id selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..." - ) - - if not self.prediction and not np.all( - frame_data[self.label_column] == frame_data[self.label_column][0] - ): - raise ValueError( - """Label selection contains multiple different values for a single selection (group by) of game_id and frame_id, - make sure this is not the case. 
Each group can only have 1 label.""" - ) - - adjacency_matrix = compute_adjacency_matrix( - settings=self.settings, **frame_data - ) - edge_features, self._edge_feature_dims = compute_edge_features( - adjacency_matrix=adjacency_matrix, - funcs=self.edge_feature_funcs, - opts=self.feature_opts, - settings=self.settings, - **frame_data, - ) - - node_features, self._node_feature_dims = compute_node_features( - funcs=self.node_feature_funcs, - opts=self.feature_opts, - settings=self.settings, - **frame_data, - ) - - if self.global_feature_cols: - failed = [ - col - for col in self.global_feature_cols - if not np.all(frame_data[col] == frame_data[col][0]) - ] - if failed: - raise ValueError( - f"""graph_feature_cols contains multiple different values for a group in the groupby ({Group.BY_FRAME}) selection for the columns {failed}. Make sure each group has the same values per individual column.""" - ) - - global_features = ( - np.asarray([frame_data[col] for col in self.global_feature_cols]).T[0] - if self.global_feature_cols - else None - ) - for col in self.global_feature_cols: - self._node_feature_dims[col] = 1 - - node_features = add_global_features( - node_features=node_features, - global_features=global_features, - global_feature_type=self.global_feature_type, - **frame_data, - ) - return { - "e": pl.Series( - [edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "x": pl.Series( - [node_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "a": pl.Series( - [adjacency_matrix.tolist()], dtype=pl.List(pl.List(pl.Int32)) - ), - "e_shape_0": edge_features.shape[0], - "e_shape_1": edge_features.shape[1], - "x_shape_0": node_features.shape[0], - "x_shape_1": node_features.shape[1], - "a_shape_0": adjacency_matrix.shape[0], - "a_shape_1": adjacency_matrix.shape[1], - self.graph_id_column: frame_data[self.graph_id_column][0], - self.label_column: frame_data[self.label_column][0], - } - - @property - def return_dtypes(self): - return pl.Struct( - { - 
"e": pl.List(pl.List(pl.Float64)), - "x": pl.List(pl.List(pl.Float64)), - "a": pl.List(pl.List(pl.Float64)), - "e_shape_0": pl.Int64, - "e_shape_1": pl.Int64, - "x_shape_0": pl.Int64, - "x_shape_1": pl.Int64, - "a_shape_0": pl.Int64, - "a_shape_1": pl.Int64, - self.graph_id_column: pl.String, - self.label_column: pl.Int64, - } - ) - - def _convert(self): - # Group and aggregate in one step - return ( - self.dataset.group_by(Group.BY_FRAME, maintain_order=True) - .agg( - pl.map_groups( - exprs=self._exprs_variables, - function=self._compute, - return_dtype=self.return_dtypes, - ).alias("result_dict") - ) - .with_columns( - [ - *[ - pl.col("result_dict").struct.field(f).alias(f) - for f in [ - "a", - "e", - "x", - self.graph_id_column, - self.label_column, - ] - ], - *[ - pl.col("result_dict") - .struct.field(f"{m}_shape_{i}") - .alias(f"{m}_shape_{i}") - for m in ["a", "e", "x"] - for i in [0, 1] - ], - ] - ) - .drop("result_dict") - ) - - def to_graph_frames(self) -> List[dict]: - def process_chunk(chunk: pl.DataFrame) -> List[dict]: - return [ - { - "a": make_sparse( - reshape_from_size( - chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] - ) - ), - "x": reshape_from_size( - chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] - ), - "e": reshape_from_size( - chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] - ), - "y": np.asarray([chunk[self.label_column][i]]), - "id": chunk[self.graph_id_column][i], - } - for i in range(len(chunk)) - ] - - graph_df = self._convert() - self.graph_frames = [ - graph - for chunk in graph_df.lazy() - .collect(engine="gpu") - .iter_slices(self.chunk_size) - for graph in process_chunk(chunk) - ] - return self.graph_frames - - def to_spektral_graphs(self) -> List[Graph]: - if not self.graph_frames: - self.to_graph_frames() - - return [ - Graph( - x=d["x"], - a=d["a"], - e=d["e"], - y=d["y"], - id=d["id"], - ) - for d in self.graph_frames - ] - - def to_pickle(self, file_path: str, verbose: bool = False) -> 
None: - """ - We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y - To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data) - """ - if not file_path.endswith("pickle.gz"): - raise ValueError( - "Only compressed pickle files of type 'some_file_name.pickle.gz' are supported..." - ) - - if not self.graph_frames: - self.to_graph_frames() - - if verbose: - print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...") - - import pickle - import gzip - from pathlib import Path - - path = Path(file_path) - - directories = path.parent - directories.mkdir(parents=True, exist_ok=True) - - with gzip.open(file_path, "wb") as file: - pickle.dump(self.graph_frames, file) - - def plot( - self, - file_path: str, - fps: int = None, - timestamp: pl.duration = None, - end_timestamp: pl.duration = None, - period_id: int = None, - team_color_a: str = "#CD0E61", - team_color_b: str = "#0066CC", - ball_color: str = "black", - sort: bool = True, - color_by: Literal["ball_owning", "static_home_away"] = "ball_owning", - ): - """ - Plot tracking data as a static image or video file. - - This method visualizes tracking data for players and the ball. It can generate either: - - A single PNG image (if either fps or end_timestamp is None, or both are None) - - An MP4 video (if both fps and end_timestamp are provided) - - Parameters - ---------- - file_path : str - The output path where the PNG or MP4 file will be saved - fps : int, optional - Frames per second for video output. If None, a static image is generated - timestamp : pl.duration, optional - The starting timestamp to plot. If None, starts from the beginning of available data - end_timestamp : pl.duration, optional - The ending timestamp for video output. If None, a static image is generated - period_id : int, optional - ID of the match period to visualize. 
If None, all periods are included - team_color_a : str, default "#CD0E61" - Hex color code for Team A visualization - team_color_b : str, default "#0066CC" - Hex color code for Team B visualization - ball_color : str, default "black" - Color for ball visualization - color_by : Literal["ball_owning", "static_home_away"], default "ball_owning" - Method for coloring the teams: - - "ball_owning": Colors teams based on ball possession - - "static_home_away": Uses static colors for home and away teams - - Returns - ------- - None - The function saves the output file to the specified file_path but doesn't return any value - - Notes - ----- - Output file type is determined by parameters: - - PNG: Generated when either fps or end_timestamp is None, or both are None - - MP4: Generated when both fps and end_timestamp are provided - - Raises - ------ - ValueError - If file extension doesn't match the parameters provided (e.g., .mp4 extension - but missing fps or end_timestamp, or .png extension with both fps and end_timestamp) - """ - try: - import matplotlib.animation as animation - import matplotlib.pyplot as plt - from matplotlib.gridspec import GridSpec - except ImportError: - raise ImportError( - "Seems like you don't have matplotlib installed. Please" - " install it using: pip install matplotlib" - ) - - if (fps is None and end_timestamp is not None) or ( - fps is not None and end_timestamp is None - ): - raise ValueError( - "Both 'fps' and 'end_timestamp' must be provided together to generate a video. 
" - ) - - # Determine the output type based on parameters - generate_video = fps is not None and end_timestamp is not None - - # Get file extension if it exists - path = pathlib.Path(file_path) - file_extension = path.suffix.lower() if path.suffix else "" - - # If no extension, add the appropriate one based on parameters - if not file_extension: - suffix = ".mp4" if generate_video else ".png" - file_path = str(path.with_suffix(suffix)) - - # Otherwise, validate that the extension matches the parameters - else: - if generate_video and file_extension != ".mp4": - raise ValueError( - f"Parameters fps and end_timestamp indicate video output, " - f"but file extension is '{file_extension}'. Use '.mp4' extension for video output." - ) - elif not generate_video and file_extension == ".mp4": - raise ValueError( - "To generate an MP4 video, both 'fps' and 'end_timestamp' must be provided. " - "For static image output, use a '.png' extension." - ) - elif not generate_video and file_extension != ".png": - raise ValueError( - f"For static image output, use '.png' extension instead of '{file_extension}'." 
- ) - - self._team_color_a = team_color_a - self._team_color_b = team_color_b - self._ball_color = ball_color - self._color_by = color_by - - if period_id is not None and not isinstance(period_id, int): - raise TypeError("period_id should be of type integer") - - if all(x is None for x in [timestamp, end_timestamp, period_id]): - # No filters specified, use the entire dataset - df = self.dataset - elif timestamp is not None and period_id is not None: - if end_timestamp is not None: - # Both timestamp and end_timestamp provided - filter for a range - df = self.dataset.filter( - (pl.col(Column.TIMESTAMP).is_between(timestamp, end_timestamp)) - & (pl.col(Column.PERIOD_ID) == period_id) - ) - else: - # Only timestamp provided (no end_timestamp) - filter for specific timestamp - df = self.dataset.filter( - (pl.col(Column.TIMESTAMP) == timestamp) - & (pl.col(Column.PERIOD_ID) == period_id) - ) - # Handle the case where a single timestamp has multiple frame_ids - df = ( - df.with_columns( - pl.col(Column.FRAME_ID) - .rank(method="min") - .over(Column.TIMESTAMP) - .alias("frame_rank") - ) - # Keep only rows where the frame has rank = 1 (first frame for each timestamp) - .filter(pl.col("frame_rank") == 1).drop("frame_rank") - ) - else: - raise ValueError( - "Please specify both timestamp and period_id, or specify all of timestamp, end_timestamp, and period_id, or none of them." 
- ) - - if df.is_empty(): - raise ValueError("Selection is empty, please try different timestamp(s)") - - def plot_graph(): - import matplotlib.pyplot as plt - - # Plot node features in top-left - ax1 = self._fig.add_subplot(self._gs[0, 0]) - ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd") - ax1.set_xlabel(f"Node Features {self._graph.x.shape}") - - # Set y labels to integers - num_rows = self._graph.x.shape[0] - ax1.set_yticks(range(num_rows)) - ax1.set_yticklabels([str(i) for i in range(num_rows)]) - - node_feature_yticklabels = feature_ticklabels(self._node_feature_dims) - ax1.xaxis.set_ticks_position("top") - ax1.set_xticks(range(len(node_feature_yticklabels))) - ax1.set_xticklabels(node_feature_yticklabels, rotation=45, ha="left") - - # Plot ajacency matrix in bottom-left - ax2 = self._fig.add_subplot(self._gs[1, 0]) - ax2.imshow(self._graph.a.toarray(), aspect="auto", cmap="YlOrRd") - ax2.set_xlabel(f"Adjacency Matrix {self._graph.a.shape}") - - # Set both x and y labels to integers - num_rows_a = self._graph.a.toarray().shape[0] - num_cols_a = self._graph.a.toarray().shape[1] - - ax2.set_yticks(range(num_rows_a)) - ax2.set_yticklabels([str(i) for i in range(num_rows_a)]) - ax2.xaxis.set_ticks_position("top") - ax2.set_xticks(range(num_cols_a)) - ax2.set_xticklabels([str(i) for i in range(num_cols_a)]) - - # Plot Edge Features on the right (spanning both rows) - ax3 = self._fig.add_subplot(self._gs[:, 1]) - - _, size_a = non_zeros(self._graph.a.toarray()[0 : self._ball_carrier_idx]) - ball_carrier_edge_idx, num_rows_e = non_zeros( - np.asarray( - [list(x) for x in self._graph.a.toarray()][self._ball_carrier_idx] - ) - ) - - im3 = ax3.imshow( - self._graph.e[size_a : num_rows_e + size_a, :], - aspect="auto", - cmap="YlOrRd", - ) - - ax3.set_yticks(range(num_rows_e)) - ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18) - ax3.set_xlabel(f"Edge Features {self._graph.e.shape}") - - labels = ax3.get_yticklabels() - if self._ball_carrier_idx in 
ball_carrier_edge_idx[0]: - idx_position = list(ball_carrier_edge_idx[0]).index( - self._ball_carrier_idx - ) - # Modify just that specific label - labels[idx_position].set_color(self._ball_carrier_color) - labels[idx_position].set_fontweight("bold") - # Set the modified labels back - ax3.set_yticklabels(labels) - - # Set x labels to edge function names at the top, rotated 45 degrees - edge_feature_xticklabels = feature_ticklabels(self._edge_feature_dims) - ax3.xaxis.set_ticks_position("top") - ax3.set_xticks(range(len(edge_feature_xticklabels))) - ax3.set_xticklabels(edge_feature_xticklabels, rotation=45, ha="left") - - plt.colorbar(im3, ax=ax3, fraction=0.1, pad=0.2) - - def plot_vertical_pitch(frame_data: pl.DataFrame): - try: - from mplsoccer import VerticalPitch - except ImportError: - raise ImportError( - "Seems like you don't have mplsoccer installed. Please" - " install it using: pip install mplsoccer" - ) - - ax4 = self._fig.add_subplot(self._gs[:, 2]) - pitch = VerticalPitch( - pitch_type="secondspectrum", - pitch_length=self.pitch_dimensions.pitch_length, - pitch_width=self.pitch_dimensions.pitch_width, - pitch_color="#ffffff", - pad_top=-0.05, - ) - pitch.draw(ax=ax4) - player_and_ball(frame_data=frame_data, ax=ax4) - direction_of_play_arrow(ax=ax4) - - def feature_ticklabels(feature_dims): - _feature_ticklabels = [] - for key, value in feature_dims.items(): - if value == 1: - _feature_ticklabels.append(key) - else: - _feature_ticklabels.extend([key] + [None] * (value - 1)) - return _feature_ticklabels - - def direction_of_play_arrow(ax): - arrow_x = -30 - arrow_y = -7.5 - arrow_dx = 0 - arrow_dy = 15 - - if self.settings.orientation == Orientation.STATIC_HOME_AWAY: - if self._ball_owning_team_id != self.settings.home_team_id: - arrow_y = arrow_y * -1 - arrow_dy = arrow_dy * -1 - elif self.settings.orientation == Orientation.BALL_OWNING_TEAM: - pass - else: - raise ValueError(f"Unsupported orientation {self.settings.orientation}") - - # Create the arrow 
to indicate direction of play - ax.arrow( - arrow_x, - arrow_y, - arrow_dx, - arrow_dy, - head_width=3, - head_length=2, - fc="#c2c2c2", - ec="#c2c2c2", - width=0.5, - length_includes_head=True, - zorder=1, - ) - - def player_and_ball(frame_data, ax): - if self._color_by == "ball_owning": - team_id = self._ball_owning_team_id - elif self._color_by == "static_home_away": - team_id = self.settings.home_team_id - else: - raise ValueError(f"Unsupported color_by {self._color_by}") - - self._ball_carrier_color = None - - for i, r in enumerate(frame_data.iter_rows(named=True)): - v, vy, vx, y, x = ( - r[Column.SPEED], - r[Column.VX], - r[Column.VY], - r[Column.X], - r[Column.Y], - ) - is_ball = True if r[Column.TEAM_ID] == self.settings.ball_id else False - - if not is_ball: - if team_id is None: - team_id = r[Column.TEAM_ID] - - color = ( - self._team_color_a - if r[Column.TEAM_ID] == team_id - else self._team_color_b - ) - - if r[Column.IS_BALL_CARRIER] == True: - self._ball_carrier_color = color - - ax.scatter(x, y, color=color, s=450) - - if v > 1.0: - ax.annotate( - "", - xy=(x + vx, y + vy), - xytext=(x, y), - arrowprops=dict(arrowstyle="->", color=color, lw=3), - ) - - else: - ax.scatter(x, y, color=self._ball_color, s=250, zorder=10) - # # Text with white border - text = ax.text( - x + (-1.2 if is_ball else 0.0), - y + (-1.2 if is_ball else 0.0), - i, - color=self._ball_color if is_ball else color, - fontsize=12, - ha="center", - va="center", - zorder=15 if is_ball else 5, - ) - - import matplotlib.patheffects as path_effects - - text.set_path_effects( - [ - path_effects.Stroke(linewidth=6, foreground="white"), - path_effects.Normal(), - ] - ) - ax.set_xlabel(f"Label: {frame_data['label'][0]}", fontsize=22) - - def frame_plot(self, frame_data): - self._gs = GridSpec( - 2, - 3, - width_ratios=[2, 1, 3], - height_ratios=[1, 1], - wspace=0.1, - hspace=0.06, - left=0.05, - right=1.0, - bottom=0.05, - ) - - # Process the current frame - features = 
self._compute([frame_data[col] for col in self._exprs_variables]) - a = make_sparse( - reshape_from_size( - features["a"], features["a_shape_0"], features["a_shape_1"] - ) - ) - x = reshape_from_size( - features["x"], features["x_shape_0"], features["x_shape_1"] - ) - e = reshape_from_size( - features["e"], features["e_shape_0"], features["e_shape_1"] - ) - y = np.asarray([features[self.label_column]]) - - self._graph = Graph( - a=a, - x=x, - e=e, - y=y, - ) - - self._ball_carrier_idx = np.where( - frame_data[Column.IS_BALL_CARRIER] == True - )[0][0] - self._ball_owning_team_id = list(frame_data[Column.BALL_OWNING_TEAM_ID])[0] - - plot_vertical_pitch(frame_data) - plot_graph() - - plt.tight_layout() - - self._fig = plt.figure(figsize=(25, 18)) - self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05) - - if sort: - df = self._sort(df) - - if generate_video: - writer = animation.FFMpegWriter(fps=fps, bitrate=1800) - - with writer.saving(self._fig, file_path, dpi=300): - for group_id, frame_data in df.group_by( - Group.BY_FRAME, maintain_order=True - ): - self._fig.clear() - frame_plot(self, frame_data) - writer.grab_frame() - - else: - frame_plot(self, frame_data=df) - plt.savefig(file_path, dpi=300) diff --git a/unravel/soccer/graphs/graph_frame.py b/unravel/soccer/graphs/graph_frame.py deleted file mode 100644 index a0c5b9f4..00000000 --- a/unravel/soccer/graphs/graph_frame.py +++ /dev/null @@ -1,104 +0,0 @@ -import numpy as np - -import warnings - -from typing import Union - -from dataclasses import dataclass, field - -from warnings import * - -from spektral.data import Graph - -from .features import ( - delaunay_adjacency_matrix, - adjacency_matrix, - node_features, - edge_features, -) -from ...utils import ( - DefaultGraphSettings, - DefaultTrackingModel, - DefaultGraphFrame, - AdjacencyMatrixType, - AdjacenyMatrixConnectType, - AdjcacenyMatrixTypeNotSetException, -) - - -@dataclass -class GraphFrame(DefaultGraphFrame): - - def to_spektral_graph(self) -> 
Graph: - if self.graph_data: - return Graph( - x=self.graph_data["x"], - a=self.graph_data["a"], - e=self.graph_data["e"], - y=self.graph_data["y"], - id=self.graph_id, - ) - else: - return None - - def _adjaceny_matrix(self): - """ - Create adjeceny matrices. If we specify the Adjaceny Matrix type to be Delaunay it's created as the 'general' A, - else we create a seperate one as A_delaunay. - This way we can use the Delaunay matrix in the Edge Features if it's not used as the Adj Matrix - """ - if not self.settings.adjacency_matrix_type: - raise AdjcacenyMatrixTypeNotSetException( - "AdjacencyMatrixTypeNotSet Error... Please set `adjacency_matrix_type`..." - ) - elif self.settings.adjacency_matrix_type == AdjacencyMatrixType.DELAUNAY: - A = delaunay_adjacency_matrix( - self.data.attacking_players, - self.data.defending_players, - self.settings.adjacency_matrix_connect_type, - self.data.ball_carrier_idx, - self.settings.self_loop_ball, - ) - A_delaunay = None - else: - A = adjacency_matrix( - self.data.attacking_players, - self.data.defending_players, - self.settings.adjacency_matrix_connect_type, - self.settings.adjacency_matrix_type, - self.data.ball_carrier_idx, - ) - A_delaunay = delaunay_adjacency_matrix( - self.data.attacking_players, - self.data.defending_players, - self.settings.adjacency_matrix_connect_type, - self.data.ball_carrier_idx, - self.settings.self_loop_ball, - ) - return A, A_delaunay - - def _node_features(self): - return node_features( - attacking_players=self.data.attacking_players, - defending_players=self.data.defending_players, - ball=self.data.ball, - max_player_speed=self.settings.max_player_speed, - max_ball_speed=self.settings.max_ball_speed, - ball_carrier_idx=self.data.ball_carrier_idx, - pitch_dimensions=self.settings.pitch_dimensions, - include_ball_node=True, - defending_team_node_value=self.settings.defending_team_node_value, - non_potential_receiver_node_value=self.settings.non_potential_receiver_node_value, - ) - - def 
_edge_features(self, A, A_delaunay): - return edge_features( - self.data.attacking_players, - self.data.defending_players, - self.data.ball, - self.settings.max_player_speed, - self.settings.max_ball_speed, - self.settings.pitch_dimensions, - A, - A_delaunay, - ) diff --git a/unravel/soccer/graphs/graph_settings.py b/unravel/soccer/graphs/graph_settings.py index f066a2dd..88cdd586 100644 --- a/unravel/soccer/graphs/graph_settings.py +++ b/unravel/soccer/graphs/graph_settings.py @@ -2,15 +2,31 @@ from ...utils import DefaultGraphSettings -from kloppy.domain import MetricPitchDimensions +from dataclasses import dataclass, field +from kloppy.domain import MetricPitchDimensions, Orientation + +from ..dataset import Constant + +import numpy as np @dataclass -class SoccerGraphSettings(DefaultGraphSettings): - infer_goalkeepers: bool = True +class GraphSettingsPolars(DefaultGraphSettings): + ball_id: str = Constant.BALL + goalkeeper_id: str = "GK" boundary_correction: float = None + home_team_id: str = None + away_team_id: str = None + orientation: Orientation = None non_potential_receiver_node_value: float = 0.1 ball_carrier_treshold: float = 25.0 + pitch_dimensions: MetricPitchDimensions = field( + init=False, repr=False, default_factory=MetricPitchDimensions + ) + + def __post_init__(self): + self._sport_specific_checks() + self._set_additional_settings() @property def pitch_dimensions(self) -> int: @@ -19,9 +35,23 @@ def pitch_dimensions(self) -> int: @pitch_dimensions.setter def pitch_dimensions(self, pitch_dimensions: MetricPitchDimensions) -> None: self._pitch_dimensions = pitch_dimensions + self._set_additional_settings() def _sport_specific_checks(self): if self.non_potential_receiver_node_value > 1: self.non_potential_receiver_node_value = 1 elif self.non_potential_receiver_node_value < 0: self.non_potential_receiver_node_value = 0 + + def _set_additional_settings(self): + self.max_distance = np.sqrt( + self.pitch_dimensions.pitch_length**2 + 
self.pitch_dimensions.pitch_width**2 + ) + self.max_goal_distance = np.sqrt( + self.pitch_dimensions.pitch_length**2 + self.pitch_dimensions.pitch_width**2 + ) + self.goal_mouth_position = ( + self.pitch_dimensions.x_dim.max, + (self.pitch_dimensions.y_dim.max + self.pitch_dimensions.y_dim.min) / 2, + 0.0, + ) diff --git a/unravel/soccer/graphs/graph_settings_pl.py b/unravel/soccer/graphs/graph_settings_pl.py deleted file mode 100644 index 88cdd586..00000000 --- a/unravel/soccer/graphs/graph_settings_pl.py +++ /dev/null @@ -1,57 +0,0 @@ -from dataclasses import dataclass - -from ...utils import DefaultGraphSettings - -from dataclasses import dataclass, field -from kloppy.domain import MetricPitchDimensions, Orientation - -from ..dataset import Constant - -import numpy as np - - -@dataclass -class GraphSettingsPolars(DefaultGraphSettings): - ball_id: str = Constant.BALL - goalkeeper_id: str = "GK" - boundary_correction: float = None - home_team_id: str = None - away_team_id: str = None - orientation: Orientation = None - non_potential_receiver_node_value: float = 0.1 - ball_carrier_treshold: float = 25.0 - pitch_dimensions: MetricPitchDimensions = field( - init=False, repr=False, default_factory=MetricPitchDimensions - ) - - def __post_init__(self): - self._sport_specific_checks() - self._set_additional_settings() - - @property - def pitch_dimensions(self) -> int: - return self._pitch_dimensions - - @pitch_dimensions.setter - def pitch_dimensions(self, pitch_dimensions: MetricPitchDimensions) -> None: - self._pitch_dimensions = pitch_dimensions - self._set_additional_settings() - - def _sport_specific_checks(self): - if self.non_potential_receiver_node_value > 1: - self.non_potential_receiver_node_value = 1 - elif self.non_potential_receiver_node_value < 0: - self.non_potential_receiver_node_value = 0 - - def _set_additional_settings(self): - self.max_distance = np.sqrt( - self.pitch_dimensions.pitch_length**2 + self.pitch_dimensions.pitch_width**2 - ) - 
self.max_goal_distance = np.sqrt( - self.pitch_dimensions.pitch_length**2 + self.pitch_dimensions.pitch_width**2 - ) - self.goal_mouth_position = ( - self.pitch_dimensions.x_dim.max, - (self.pitch_dimensions.y_dim.max + self.pitch_dimensions.y_dim.min) / 2, - 0.0, - ) diff --git a/unravel/utils/objects/default_graph_converter.py b/unravel/utils/objects/default_graph_converter.py index 75a968bc..71c456ae 100644 --- a/unravel/utils/objects/default_graph_converter.py +++ b/unravel/utils/objects/default_graph_converter.py @@ -142,7 +142,16 @@ def __post_init__(self): raise Exception("'verbose' should be of type boolean (bool)") def _shuffle(self): - raise NotImplementedError() + if self.settings.random_seed is None or self.settings.random_seed == False: + self.dataset = self._sort(self.dataset) + if isinstance(self.settings.random_seed, int): + self.dataset = self.dataset.sample( + fraction=1.0, seed=self.settings.random_seed + ) + elif self.settings.random_seed == True: + self.dataset = self.dataset.sample(fraction=1.0) + else: + self.dataset = self._sort(self.dataset) def _sport_specific_checks(self): raise NotImplementedError( From 947edf443754ff09cb7b2c7d838d5921eb7026e1 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 09:44:41 +0200 Subject: [PATCH 3/6] refactor default graph converter --- .../graphs/graph_converter.py | 131 +++++------------- unravel/soccer/graphs/graph_converter.py | 113 --------------- .../utils/objects/default_graph_converter.py | 131 +++++++++++++++++- 3 files changed, 161 insertions(+), 214 deletions(-) diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index ae8310a6..3bbbb701 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -171,7 +171,7 @@ def _apply_graph_settings(self, settings): ) @property - def __exprs_variables(self): + def _exprs_variables(self): exprs_variables = 
[ Column.X, Column.Y, @@ -194,8 +194,8 @@ def __exprs_variables(self): ) return exprs - def __compute(self, args: List[pl.Series]) -> dict: - d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)} + def _compute(self, args: List[pl.Series]) -> dict: + d = {col: args[i].to_numpy() for i, col in enumerate(self._exprs_variables)} if self.graph_feature_cols is not None: failed = [ @@ -277,32 +277,14 @@ def __compute(self, args: List[pl.Series]) -> dict: self.label_column: d[self.label_column][0], } - @property - def return_dtypes(self): - return pl.Struct( - { - "e": pl.List(pl.List(pl.Float64)), - "x": pl.List(pl.List(pl.Float64)), - "a": pl.List(pl.List(pl.Float64)), - "e_shape_0": pl.Int64, - "e_shape_1": pl.Int64, - "x_shape_0": pl.Int64, - "x_shape_1": pl.Int64, - "a_shape_0": pl.Int64, - "a_shape_1": pl.Int64, - self.graph_id_column: pl.String, - self.label_column: pl.Int64, - } - ) - def _convert(self): # Group and aggregate in one step return ( self.dataset.group_by(Group.BY_FRAME, maintain_order=True) .agg( pl.map_groups( - exprs=self.__exprs_variables, - function=self.__compute, + exprs=self._exprs_variables, + function=self._compute, return_dtype=self.return_dtypes, ).alias("result_dict") ) @@ -330,76 +312,33 @@ def _convert(self): .drop("result_dict") ) - def to_graph_frames(self) -> List[dict]: - def process_chunk(chunk: pl.DataFrame) -> List[dict]: - return [ - { - "a": make_sparse( - reshape_from_size( - chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] - ) - ), - "x": reshape_from_size( - chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] - ), - "e": reshape_from_size( - chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] - ), - "y": np.asarray([chunk[self.label_column][i]]), - "id": chunk[self.graph_id_column][i], - } - for i in range(len(chunk)) - ] - - graph_df = self._convert() - self.graph_frames = [ - graph - for chunk in graph_df.lazy() - .collect(engine="gpu") - .iter_slices(self.chunk_size) - 
for graph in process_chunk(chunk) - ] - return self.graph_frames - - def to_spektral_graphs(self) -> List[Graph]: - if not self.graph_frames: - self.to_graph_frames() - - return [ - Graph( - x=d["x"], - a=d["a"], - e=d["e"], - y=d["y"], - id=d["id"], - ) - for d in self.graph_frames - ] - - def to_pickle(self, file_path: str, verbose: bool = False) -> None: - """ - We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y - To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data) - """ - if not file_path.endswith("pickle.gz"): - raise ValueError( - "Only compressed pickle files of type 'some_file_name.pickle.gz' are supported..." - ) - - if not self.graph_frames: - self.to_graph_frames() - - if verbose: - print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...") - - import pickle - import gzip - from pathlib import Path - - path = Path(file_path) - - directories = path.parent - directories.mkdir(parents=True, exist_ok=True) - - with gzip.open(file_path, "wb") as file: - pickle.dump(self.graph_frames, file) + # def to_graph_frames(self) -> List[dict]: + # def process_chunk(chunk: pl.DataFrame) -> List[dict]: + # return [ + # { + # "a": make_sparse( + # reshape_from_size( + # chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] + # ) + # ), + # "x": reshape_from_size( + # chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] + # ), + # "e": reshape_from_size( + # chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] + # ), + # "y": np.asarray([chunk[self.label_column][i]]), + # "id": chunk[self.graph_id_column][i], + # } + # for i in range(len(chunk)) + # ] + + # graph_df = self._convert() + # self.graph_frames = [ + # graph + # for chunk in graph_df.lazy() + # .collect(engine="gpu") + # .iter_slices(self.chunk_size) + # for graph in process_chunk(chunk) + # ] + # return self.graph_frames diff --git a/unravel/soccer/graphs/graph_converter.py 
b/unravel/soccer/graphs/graph_converter.py index b95c5711..7f4abb07 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -132,27 +132,6 @@ def _sample(self): pl.col(Column.FRAME_ID) % (1.0 / self.sample_rate) == 0 ) - def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): - for i, func in enumerate(funcs): - # Check if it has the attributes added by the decorator - if not hasattr(func, "feature_type"): - func_str = inspect.getsource(func).strip() - raise Exception( - f"Error processing feature function:\n" - f"{func.__name__} defined as:\n" - f"{func_str}\n\n" - "Function is missing the @graph_feature decorator. " - ) - - if func.feature_type != feature_type: - func_str = inspect.getsource(func).strip() - raise Exception( - f"Error processing feature function:\n" - f"{func.__name__} defined as:\n" - f"{func_str}\n\n" - "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. 
" - ) - @staticmethod def _sort(df): sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( @@ -608,24 +587,6 @@ def _compute(self, args: List[pl.Series]) -> dict: self.label_column: frame_data[self.label_column][0], } - @property - def return_dtypes(self): - return pl.Struct( - { - "e": pl.List(pl.List(pl.Float64)), - "x": pl.List(pl.List(pl.Float64)), - "a": pl.List(pl.List(pl.Float64)), - "e_shape_0": pl.Int64, - "e_shape_1": pl.Int64, - "x_shape_0": pl.Int64, - "x_shape_1": pl.Int64, - "a_shape_0": pl.Int64, - "a_shape_1": pl.Int64, - self.graph_id_column: pl.String, - self.label_column: pl.Int64, - } - ) - def _convert(self): # Group and aggregate in one step return ( @@ -661,80 +622,6 @@ def _convert(self): .drop("result_dict") ) - def to_graph_frames(self) -> List[dict]: - def process_chunk(chunk: pl.DataFrame) -> List[dict]: - return [ - { - "a": make_sparse( - reshape_from_size( - chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] - ) - ), - "x": reshape_from_size( - chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] - ), - "e": reshape_from_size( - chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] - ), - "y": np.asarray([chunk[self.label_column][i]]), - "id": chunk[self.graph_id_column][i], - } - for i in range(len(chunk)) - ] - - graph_df = self._convert() - self.graph_frames = [ - graph - for chunk in graph_df.lazy() - .collect(engine="gpu") - .iter_slices(self.chunk_size) - for graph in process_chunk(chunk) - ] - return self.graph_frames - - def to_spektral_graphs(self) -> List[Graph]: - if not self.graph_frames: - self.to_graph_frames() - - return [ - Graph( - x=d["x"], - a=d["a"], - e=d["e"], - y=d["y"], - id=d["id"], - ) - for d in self.graph_frames - ] - - def to_pickle(self, file_path: str, verbose: bool = False) -> None: - """ - We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y - To use for training with Spektral feed the loaded pickle data to 
CustomDataset(data=pickled_data) - """ - if not file_path.endswith("pickle.gz"): - raise ValueError( - "Only compressed pickle files of type 'some_file_name.pickle.gz' are supported..." - ) - - if not self.graph_frames: - self.to_graph_frames() - - if verbose: - print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...") - - import pickle - import gzip - from pathlib import Path - - path = Path(file_path) - - directories = path.parent - directories.mkdir(parents=True, exist_ok=True) - - with gzip.open(file_path, "wb") as file: - pickle.dump(self.graph_frames, file) - def plot( self, file_path: str, diff --git a/unravel/utils/objects/default_graph_converter.py b/unravel/utils/objects/default_graph_converter.py index 71c456ae..2ba9e1f8 100644 --- a/unravel/utils/objects/default_graph_converter.py +++ b/unravel/utils/objects/default_graph_converter.py @@ -1,5 +1,6 @@ import logging import sys +import inspect from dataclasses import dataclass, field, asdict @@ -23,6 +24,10 @@ from .default_graph_settings import DefaultGraphSettings from .custom_spektral_dataset import CustomSpektralDataset +from ..features.utils import make_sparse, reshape_from_size + +import numpy as np + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) stdout_handler = logging.StreamHandler(sys.stdout) @@ -164,17 +169,63 @@ def _apply_graph_settings(self): def _convert(self): raise NotImplementedError() - def to_graph_frames(self) -> dict: - raise NotImplementedError() + def to_spektral_graphs(self) -> List[Graph]: + if not self.graph_frames: + self.to_graph_frames() - def to_pickle(self) -> None: - raise NotImplementedError() + return [ + Graph( + x=d["x"], + a=d["a"], + e=d["e"], + y=d["y"], + id=d["id"], + ) + for d in self.graph_frames + ] + + def to_pickle(self, file_path: str, verbose: bool = False) -> None: + """ + We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y + To use for training with Spektral feed the loaded 
pickle data to CustomDataset(data=pickled_data) + """ + if not file_path.endswith("pickle.gz"): + raise ValueError( + "Only compressed pickle files of type 'some_file_name.pickle.gz' are supported..." + ) + + if not self.graph_frames: + self.to_graph_frames() + + if verbose: + print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...") + + import pickle + import gzip + from pathlib import Path + + path = Path(file_path) + + directories = path.parent + directories.mkdir(parents=True, exist_ok=True) + + with gzip.open(file_path, "wb") as file: + pickle.dump(self.graph_frames, file) def to_spektral_graphs(self) -> List[Graph]: if not self.graph_frames: self.to_graph_frames() - return [g.to_spektral_graph() for g in self.graph_frames] + return [ + Graph( + x=d["x"], + a=d["a"], + e=d["e"], + y=d["y"], + id=d["id"], + ) + for d in self.graph_frames + ] def to_custom_dataset(self) -> CustomSpektralDataset: """ @@ -182,3 +233,73 @@ def to_custom_dataset(self) -> CustomSpektralDataset: for docs see https://graphneural.network/creating-dataset/ """ return CustomSpektralDataset(graphs=self.to_spektral_graphs()) + + def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): + for i, func in enumerate(funcs): + # Check if it has the attributes added by the decorator + if not hasattr(func, "feature_type"): + func_str = inspect.getsource(func).strip() + raise Exception( + f"Error processing feature function:\n" + f"{func.__name__} defined as:\n" + f"{func_str}\n\n" + "Function is missing the @graph_feature decorator. " + ) + + if func.feature_type != feature_type: + func_str = inspect.getsource(func).strip() + raise Exception( + f"Error processing feature function:\n" + f"{func.__name__} defined as:\n" + f"{func_str}\n\n" + "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. 
" + ) + + @property + def return_dtypes(self): + return pl.Struct( + { + "e": pl.List(pl.List(pl.Float64)), + "x": pl.List(pl.List(pl.Float64)), + "a": pl.List(pl.List(pl.Float64)), + "e_shape_0": pl.Int64, + "e_shape_1": pl.Int64, + "x_shape_0": pl.Int64, + "x_shape_1": pl.Int64, + "a_shape_0": pl.Int64, + "a_shape_1": pl.Int64, + self.graph_id_column: pl.String, + self.label_column: pl.Int64, + } + ) + + def to_graph_frames(self) -> List[dict]: + def process_chunk(chunk: pl.DataFrame) -> List[dict]: + return [ + { + "a": make_sparse( + reshape_from_size( + chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] + ) + ), + "x": reshape_from_size( + chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] + ), + "e": reshape_from_size( + chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] + ), + "y": np.asarray([chunk[self.label_column][i]]), + "id": chunk[self.graph_id_column][i], + } + for i in range(len(chunk)) + ] + + graph_df = self._convert() + self.graph_frames = [ + graph + for chunk in graph_df.lazy() + .collect(engine="gpu") + .iter_slices(self.chunk_size) + for graph in process_chunk(chunk) + ] + return self.graph_frames From dbaa3e2a534d193acfcbe5eec164e35de6ba302b Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 09:49:29 +0200 Subject: [PATCH 4/6] remove dangling files --- .../graphs/graph_converter.py | 31 -- unravel/utils/objects/__init__.py | 4 - .../utils/objects/custom_spektral_dataset.py | 8 - unravel/utils/objects/default_ball.py | 56 ---- unravel/utils/objects/default_graph_frame.py | 163 ---------- unravel/utils/objects/default_player.py | 59 ---- unravel/utils/objects/default_tracking.py | 294 ------------------ 7 files changed, 615 deletions(-) delete mode 100644 unravel/utils/objects/default_ball.py delete mode 100644 unravel/utils/objects/default_graph_frame.py delete mode 100644 unravel/utils/objects/default_player.py delete mode 100644 unravel/utils/objects/default_tracking.py diff --git 
a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index 3bbbb701..327f94bc 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -311,34 +311,3 @@ def _convert(self): ) .drop("result_dict") ) - - # def to_graph_frames(self) -> List[dict]: - # def process_chunk(chunk: pl.DataFrame) -> List[dict]: - # return [ - # { - # "a": make_sparse( - # reshape_from_size( - # chunk["a"][i], chunk["a_shape_0"][i], chunk["a_shape_1"][i] - # ) - # ), - # "x": reshape_from_size( - # chunk["x"][i], chunk["x_shape_0"][i], chunk["x_shape_1"][i] - # ), - # "e": reshape_from_size( - # chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] - # ), - # "y": np.asarray([chunk[self.label_column][i]]), - # "id": chunk[self.graph_id_column][i], - # } - # for i in range(len(chunk)) - # ] - - # graph_df = self._convert() - # self.graph_frames = [ - # graph - # for chunk in graph_df.lazy() - # .collect(engine="gpu") - # .iter_slices(self.chunk_size) - # for graph in process_chunk(chunk) - # ] - # return self.graph_frames diff --git a/unravel/utils/objects/__init__.py b/unravel/utils/objects/__init__.py index e513679b..d4356f60 100644 --- a/unravel/utils/objects/__init__.py +++ b/unravel/utils/objects/__init__.py @@ -1,8 +1,4 @@ -from .default_player import DefaultPlayer -from .default_ball import DefaultBall -from .default_tracking import DefaultTrackingModel from .custom_spektral_dataset import CustomSpektralDataset -from .default_graph_frame import DefaultGraphFrame from .default_graph_settings import DefaultGraphSettings from .default_graph_converter import DefaultGraphConverter from .default_dataset import DefaultDataset diff --git a/unravel/utils/objects/custom_spektral_dataset.py b/unravel/utils/objects/custom_spektral_dataset.py index e280251c..230d51c2 100644 --- a/unravel/utils/objects/custom_spektral_dataset.py +++ 
b/unravel/utils/objects/custom_spektral_dataset.py @@ -19,8 +19,6 @@ from spektral.data import Dataset, Graph from spektral.data.utils import get_spec -from .default_graph_frame import DefaultGraphFrame - from ..exceptions import NoGraphIdsWarning @@ -84,12 +82,6 @@ def __convert(self, data) -> List[Graph]: """ if isinstance(data[0], Graph): return [g for i, g in enumerate(data) if i % self.sample == 0] - elif isinstance(data[0], DefaultGraphFrame): - return [ - g.to_spektral_graph() - for i, g in enumerate(self.data) - if i % self.sample == 0 - ] elif isinstance(data[0], dict): return [ Graph(x=g["x"], a=g["a"], e=g["e"], y=g["y"], id=g["id"]) diff --git a/unravel/utils/objects/default_ball.py b/unravel/utils/objects/default_ball.py deleted file mode 100644 index d5817f48..00000000 --- a/unravel/utils/objects/default_ball.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from dataclasses import dataclass - - -@dataclass -class DefaultBall(object): - fps: int - x1: float = np.nan - y1: float = np.nan - z1: float = 0.0 - x2: float = np.nan - y2: float = np.nan - z2: float = 0.0 - position = np.array([np.nan, np.nan]) - position3D = np.array([np.nan, np.nan, np.nan]) - - def __post_init__(self): - self.position = np.array([self.x1, self.y1]) - self.position3D = np.array([self.x1, self.y1, self.z1]) - self.next_position = np.array([self.x2, self.y2]) - self.next_position3D = np.array([self.x2, self.y2, self.z2]) - - self.set_velocity() - - def set_velocity(self): - delta_time = 1.0 / self.fps - - if not ( - np.any(np.isnan(self.next_position3D)) or np.any(np.isnan(self.position3D)) - ): - vx = (self.next_position3D[0] - self.position3D[0]) / delta_time - vy = (self.next_position3D[1] - self.position3D[1]) / delta_time - vz = (self.next_position3D[2] - self.position3D[2]) / delta_time - else: - vx = 0 - vy = 0 - vz = 0 - - self.velocity = np.asarray([vx, vy], dtype=float) - self.velocity3D = np.asarray([vx, vy, vz], dtype=float) - - if 
np.any(np.isnan(self.velocity)): - self.velocity = np.asarray([0.0, 0.0], dtype=float) - self.velocity3D = np.asarray([0.0, 0.0, 0.0], dtype=float) - - self.speed = np.sqrt(vx**2 + vy**2 + vz**2) - - def invert_position(self): - self.next_position = self.next_position * -1.0 - self.position = self.position * -1.0 - self.x1 = self.x1 * -1.0 - self.y1 = self.y1 * -1.0 - self.x2 = self.x2 * -1.0 - self.y2 = self.y2 * -1.0 - self.set_velocity() - return self diff --git a/unravel/utils/objects/default_graph_frame.py b/unravel/utils/objects/default_graph_frame.py deleted file mode 100644 index d3ca48d5..00000000 --- a/unravel/utils/objects/default_graph_frame.py +++ /dev/null @@ -1,163 +0,0 @@ -import numpy as np - -import warnings - -from typing import Union - -from dataclasses import dataclass, field - -from ..exceptions.warnings import QualityCheckWarning - -from spektral.data import Graph - -from ..features import ( - PredictionLabelType, - make_sparse, -) -from .default_graph_settings import DefaultGraphSettings -from .default_tracking import DefaultTrackingModel - - -@dataclass -class DefaultGraphFrame: - frame_id: int - data: DefaultTrackingModel - settings: DefaultGraphSettings - graph_id: Union[str, int] - label: Union[int, bool] - graph_data: dict = field(init=False, repr=False, default=None) - - def __post_init__(self): - A, A_delaunay = self._adjaceny_matrix() - X = self._node_features() - E = self._edge_features(A, A_delaunay) - - if self.settings.pad: - X, A, E = self._pad(x=X, a=A, e=E) - - if self.settings.random_seed is not False: - X, A, E = self._shuffle(x=X, a=A, e=E) - if not self.settings.label_type == PredictionLabelType.BINARY: - raise NotImplementedError( - "Need to potentially implement a shuffle for Y to follow the shuffling of X, A and E." 
- ) - - sparse_A = make_sparse(A) - Y = self._label() - - if self._quality_check(X, E): - self.graph_data = dict(x=X, a=sparse_A, e=E, y=Y, id=self.graph_id) - - def to_spektral_graph(self) -> Graph: - if self.graph_data: - return Graph( - x=self.graph_data["x"], - a=self.graph_data["a"], - e=self.graph_data["e"], - y=self.graph_data["y"], - id=self.graph_id, - ) - else: - return None - - def _label(self): - if self.settings.label_type == PredictionLabelType.BINARY: - return np.asarray([int(self.label)]) - else: - raise NotImplementedError("Label should be PredictionLabelType.BINARY") - - def _adjaceny_matrix(self): - """ - Create adjeceny matrices. If we specify the Adjaceny Matrix type to be Delaunay it's created as the 'general' A, - else we create a seperate one as A_delaunay. - This way we can use the Delaunay matrix in the Edge Features if it's not used as the Adj Matrix - """ - raise NotImplementedError() - - def _node_features(self): - raise NotImplementedError() - - def _edge_features(self, A, A_delaunay): - raise NotImplementedError() - - def _quality_check(self, X, E): - if self.settings.boundary_correction is not None: - if (np.max(X) <= 1) or (np.min(X) >= -1): - warnings.warn( - f"""Node Feature(s) outside boundary for frame={self.frame_id}, skipping...""", - QualityCheckWarning, - ) - return False - if (np.max(E) <= 1) or (np.min(E) >= -1): - warnings.warn( - f"""Edge Feature(s) outside boundary for frame={self.frame_id}, skipping...""", - QualityCheckWarning, - ) - return False - return True - - def _shuffle(self, x, a, e): - if isinstance(self.settings.random_seed, int): - np.random.seed(self.settings.random_seed) - elif self.settings.random_seed == True: - np.random.seed() - else: - pass - - # Generate a random permutation of node indices - num_nodes = x.shape[0] - permutation = np.random.permutation(num_nodes) - - # Permute the rows and columns of the adjacency matrix - a_shuffled = a[permutation, :][:, permutation] - - # Permute the rows of the 
node features matrix - - x_shuffled = x[permutation] - # Adjust the edge features matrix - # Get the indices of non-zero elements in the original adjacency matrix - row, col = np.nonzero(a) - - # Map the original indices to the new shuffled indices - row_shuffled = permutation[row] - col_shuffled = permutation[col] - - # Create a dictionary to map from original index pairs to new index pairs - index_mapping = { - (r, c): (rs, cs) - for r, c, rs, cs in zip(row, col, row_shuffled, col_shuffled) - } - - # Sort the new index pairs to ensure consistency - sorted_indices = sorted(index_mapping.values()) - - # Create an array of edge features based on the sorted indices - e_shuffled = np.zeros_like(e, dtype=float) - - # Populate the new edge features matrix - for idx, (r, c) in enumerate(sorted_indices): - original_index = list(index_mapping.values()).index((r, c)) - e_shuffled[idx] = e[original_index] - - return x_shuffled, a_shuffled, e_shuffled - - def _pad(self, x, a, e): - n_node_features = x.shape[1] - n_edge_features = e.shape[1] - - max_edges = self.settings.pad_settings.max_edges - max_nodes = self.settings.pad_settings.max_nodes - - # Padding node features - pad_x = np.zeros((max_nodes, n_node_features)) - pad_x[: x.shape[0], : x.shape[1]] = x - - # Padding adjacency matrix - pad_a = np.zeros((max_nodes, max_nodes)) - pad_a[: a.shape[0], : a.shape[1]] = a - - # Padding edge features - pad_e = np.zeros((max_edges, n_edge_features)) - pad_e[: e.shape[0], : e.shape[1]] = e - - return pad_x, pad_a, pad_e diff --git a/unravel/utils/objects/default_player.py b/unravel/utils/objects/default_player.py deleted file mode 100644 index db18e2a9..00000000 --- a/unravel/utils/objects/default_player.py +++ /dev/null @@ -1,59 +0,0 @@ -import numpy as np -from dataclasses import dataclass, field - - -@dataclass -class DefaultPlayer(object): - fps: int - x1: float = np.nan - y1: float = np.nan - x2: float = np.nan - y2: float = np.nan - is_visible: bool = False - position: 
np.array = field( - default_factory=lambda: np.array([np.nan, np.nan], dtype=float) - ) - next_position: np.array = field( - default_factory=lambda: np.array([np.nan, np.nan], dtype=float) - ) - - velocity: np.array = field( - default_factory=lambda: np.asarray([0.0, 0.0], dtype=float) - ) # velocity vector - speed: float = 0.0 # actual speed in m/s - is_gk: bool = False - - def __post_init__(self): - self.next_position = np.asarray([self.x2, self.y2], dtype=float) - self.position = np.asarray([self.x1, self.y1], dtype=float) - - self.set_velocity() - - def invert_position(self): - self.next_position = self.next_position * -1.0 - self.position = self.position * -1.0 - self.x1 = self.x1 * -1 - self.y1 = self.y1 * -1 - self.x2 = self.x2 * -1.0 - self.y2 = self.y2 * -1.0 - self.set_velocity() - return self - - def set_velocity(self): - dt = 1.0 / self.fps - if not ( - np.any(np.isnan(self.next_position)) or np.any(np.isnan(self.position)) - ): - vx = (self.next_position[0] - self.position[0]) / dt - vy = (self.next_position[1] - self.position[1]) / dt - else: - vx = 0 - vy = 0 - - self.velocity = np.asarray([vx, vy], dtype=float) - - # Re-check if any component of velocity is NaN and set to zero if it is - if np.any(np.isnan(self.velocity)): - self.velocity = np.asarray([0.0, 0.0], dtype=float) - - self.speed = np.sqrt(vx**2 + vy**2) diff --git a/unravel/utils/objects/default_tracking.py b/unravel/utils/objects/default_tracking.py deleted file mode 100644 index e0fedc68..00000000 --- a/unravel/utils/objects/default_tracking.py +++ /dev/null @@ -1,294 +0,0 @@ -import numpy as np -from typing import Union, Dict, List - -from kloppy.domain import ( - TrackingDataset, - Frame, - Point3D, - Point, - Orientation, - Ground, - AttackingDirection, -) - -from . 
import DefaultPlayer, DefaultBall -from ..exceptions import ( - InvalidAttackingTeamTypeException, - MissingAttackingTeam, - MissingCoordinates, - NoNextFrameWarning, -) - -import warnings - -from dataclasses import dataclass, field - - -@dataclass -class DefaultTrackingModel: - frame: TrackingDataset - - orientation: Orientation - fps: int - infer_ball_ownership: bool = False - infer_goalkeepers: bool = False - ball_carrier_treshold: bool = 25.0 - verbose: bool = False - pad_n_players: bool = None - - def __post_init__(self): - self.home_players: List[DefaultPlayer] = list() - self.away_players: List[DefaultPlayer] = list() - self.ball: DefaultBall = None - self.attacking_team: str = None - self.ball_carrier_idx: int = None - - self.set_objects_from_frame( - infer_ball_ownership=self.infer_ball_ownership, - infer_goalkeepers=self.infer_goalkeepers, - ball_carrier_treshold=self.ball_carrier_treshold, - orientation=self.orientation, - verbose=self.verbose, - pad_n_players=self.pad_n_players, - ) - - @property - def attacking_players(self): - if self.attacking_team is None: - warnings.warn( - """No key 'attacking_team' found in 'Frame'.""", MissingAttackingTeam - ) - return self.home_players - if not self.attacking_team in [ - Ground.HOME, - Ground.AWAY, - Ground.HOME.value, - Ground.AWAY.value, - ]: - raise InvalidAttackingTeamTypeException( - f"'attacking_team' should be of type {Ground.HOME} or {Ground.AWAY}" - ) - - return ( - self.home_players - if self.attacking_team in [Ground.HOME, Ground.HOME.value] - else self.away_players - ) - - @property - def defending_players(self): - if self.attacking_team is None: - warnings.warn( - """No key 'attacking_team' found in 'Frame'. 
""", MissingAttackingTeam - ) - return self.away_players - if not self.attacking_team in [ - Ground.HOME, - Ground.AWAY, - Ground.HOME.value, - Ground.AWAY.value, - ]: - raise InvalidAttackingTeamTypeException( - f"'attacking_team' should be of type {Ground.HOME} or {Ground.AWAY}" - ) - - return ( - self.home_players - if self.attacking_team in [Ground.AWAY, Ground.AWAY.value] - else self.away_players - ) - - def _distance_to_ball(self, players): - """ - Use 3D distance to compute distance to the ball. Since we don't have player 3D position, we pad it with 0s - """ - player_positions = np.asarray([p.position for p in players]) - ball_carrier_dist = np.linalg.norm( - np.pad(player_positions, ((0, 0), (0, 1)), "constant") - - self.ball.position3D, - axis=1, - ) - return ball_carrier_dist - - def _set_ball_carrier_idx(self, threshold): - if not self.ball_carrier_idx: - ball_carrier_dist = self._distance_to_ball(players=self.attacking_players) - self.ball_carrier_idx = ( - np.nanargmin(ball_carrier_dist) - if np.nanmin(ball_carrier_dist) < threshold - else None - ) - return self.ball_carrier_idx - return self.ball_carrier_idx - - def _set_goalkeeper(self, players, func): - idx = func([p.y1 for p in players]) - players[idx].is_gk = True - return players - - def _set_attacking_team(self, threshold): - home_ball_carrier_dists = self._distance_to_ball(self.home_players) - closest_home_player_dist = np.nanmin(home_ball_carrier_dists) - away_ball_carrier_dists = self._distance_to_ball(self.away_players) - closest_away_player_dist = np.nanmin(away_ball_carrier_dists) - - if ( - np.nanmin( - np.concatenate((home_ball_carrier_dists, away_ball_carrier_dists)) - ) - < threshold - ): - if closest_home_player_dist < closest_away_player_dist: - self.attacking_team = Ground.HOME - self.ball_carrier_idx = np.nanargmin(home_ball_carrier_dists) - else: - self.attacking_team = Ground.AWAY - self.ball_carrier_idx = np.nanargmin(away_ball_carrier_dists) - - def set_objects_from_frame( - 
self, - infer_ball_ownership: bool = False, - infer_goalkeepers: bool = False, - ball_carrier_treshold: float = 25.0, - orientation: Orientation = Orientation.NOT_SET, - verbose: bool = True, - pad_n_players: int = None, - ) -> tuple[List[DefaultPlayer], List[DefaultPlayer], DefaultBall, str]: - frame = self.frame - - if isinstance(frame, Frame): - fix_orientation_ltr = ( - True - if orientation == Orientation.STATIC_HOME_AWAY and infer_ball_ownership - else False - ) - - next_frame = frame.next() - - if not next_frame: - if verbose: - warnings.warn( - f"""No next_frame found, skipping...""", NoNextFrameWarning - ) - return None, None, None, None - if not frame.ball_coordinates or not next_frame.ball_coordinates: - if verbose: - warnings.warn( - f"""No ball_coordinates found in frame_id={frame.frame_id}, skipping...""", - MissingCoordinates, - ) - return None, None, None, None - if not frame.players_coordinates: - if verbose: - warnings.warn( - f"""No player_coordinates found in frame_id={frame.frame_id}, skipping...""", - MissingCoordinates, - ) - return None, None, None, None - - for pid in frame.players_data: - coords = frame.players_data[pid].coordinates - try: - if pid.positions.at_start() in ["Goalkeeper", "GK", "TW"]: - player.is_gk = True - except KeyError: # catching Kloppy key error for empty TimeContainer - pass - - if not pid in next_frame.players_data: - continue - - next_coords = next_frame.players_data[pid].coordinates - - if coords is not None: - player = DefaultPlayer( - x1=coords.x, - x2=next_coords.x, - y1=coords.y, - y2=next_coords.y, - is_visible=True, - fps=self.fps, - ) - - if pid.team.ground == Ground.HOME: - self.home_players.append(player) - elif pid.team.ground == Ground.AWAY: - self.away_players.append(player) - else: - continue - - if not self.home_players or not self.away_players: - return - - if pad_n_players: - for _ in range(0, pad_n_players - len(self.home_players)): - self.home_players.append(DefaultPlayer(fps=self.fps)) - for _ 
in range(0, pad_n_players - len(self.away_players)): - self.away_players.append(DefaultPlayer(fps=self.fps)) - - if isinstance(frame.ball_coordinates, Point): - z1, z2 = 0.0, 0.0 - elif isinstance(frame.ball_coordinates, Point3D): - z1, z2 = frame.ball_coordinates.z, next_frame.ball_coordinates.z - - self.ball = DefaultBall( - fps=self.fps, - x1=frame.ball_coordinates.x, - y1=frame.ball_coordinates.y, - z1=z1, - x2=next_frame.ball_coordinates.x, - y2=next_frame.ball_coordinates.y, - z2=z2, - ) - - self.attacking_team = ( - None - if frame.ball_owning_team is None - else frame.ball_owning_team.ground - ) - if infer_ball_ownership: - if not self.attacking_team: - self._set_attacking_team(threshold=ball_carrier_treshold) - else: - self._set_ball_carrier_idx(threshold=ball_carrier_treshold) - - attacking_direction = ( - AttackingDirection.LTR - if ( - (orientation == Orientation.BALL_OWNING_TEAM) - or ( - orientation == Orientation.STATIC_HOME_AWAY - and self.attacking_team == Ground.HOME - ) - ) - else AttackingDirection.NOT_SET - ) - - if fix_orientation_ltr and self.attacking_team == Ground.AWAY: - self.home_players = [p.invert_position() for p in self.home_players] - self.away_players = [p.invert_position() for p in self.away_players] - self.ball = self.ball.invert_position() - attacking_direction = AttackingDirection.LTR - - if infer_goalkeepers and attacking_direction == AttackingDirection.LTR: - if not any([p.is_gk for p in self.home_players]): - self.home_players = self._set_goalkeeper( - self.home_players, - func=( - np.argmin - if self.attacking_team == Ground.HOME - else np.argmax - ), - ) - if not any([p.is_gk for p in self.away_players]): - self.away_players = self._set_goalkeeper( - self.away_players, - func=( - np.argmin - if self.attacking_team == Ground.AWAY - else np.argmax - ), - ) - else: - raise NotImplementedError( - """'data' dtype is not supported. 
Make sure it's a kloppy 'Frame'""" - ) From b781edc61d6d501459d6f70f2899466ef8205020 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 10:39:48 +0200 Subject: [PATCH 5/6] aligning column names --- tests/test_american_football.py | 24 ++++---- unravel/american_football/dataset/dataset.py | 56 ++++++++++++------- unravel/american_football/dataset/objects.py | 23 ++++---- .../graphs/graph_converter.py | 22 ++++---- unravel/soccer/dataset/objects.py | 22 ++++---- 5 files changed, 85 insertions(+), 62 deletions(-) diff --git a/tests/test_american_football.py b/tests/test_american_football.py index c9552ea6..9da6951f 100644 --- a/tests/test_american_football.py +++ b/tests/test_american_football.py @@ -55,8 +55,8 @@ def default_dataset(self, coordinates: str, players: str, plays: str): max_player_acceleration=10.0, max_ball_acceleration=10.0, ) - bdb_dataset.add_graph_ids(by=["gameId", "playId"]) - bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"]) + bdb_dataset.add_graph_ids(by=["game_id", "play_id"]) + bdb_dataset.add_dummy_labels(by=["game_id", "play_id", "frame_id"]) return bdb_dataset @pytest.fixture @@ -70,8 +70,8 @@ def non_default_dataset(self, coordinates: str, players: str, plays: str): max_player_acceleration=11.0, max_ball_acceleration=12.0, ) - bdb_dataset.add_graph_ids(by=["gameId", "playId"]) - bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"]) + bdb_dataset.add_graph_ids(by=["game_id", "play_id"]) + bdb_dataset.add_dummy_labels(by=["game_id", "play_id", "frame_id"]) return bdb_dataset @pytest.fixture @@ -280,24 +280,24 @@ def test_dataset_loader(self, default_dataset: tuple): row_10 = data[10].to_dict() - assert row_10["gameId"][0] == 2021091300 - assert row_10["playId"][0] == 4845 - assert row_10["nflId"][0] == 33131 - assert row_10["frameId"][0] == 11 + assert row_10["game_id"][0] == 2021091300 + assert row_10["play_id"][0] == 4845 + assert row_10["id"][0] == 33131 + assert row_10["frame_id"][0] == 11 
assert row_10["time"][0] == datetime(2021, 9, 14, 3, 54, 18, 700000) assert row_10["jerseyNumber"][0] == 93 - assert row_10["team"][0] == "BAL" + assert row_10["team_id"][0] == "BAL" assert row_10["playDirection"][0] == "left" assert row_10["x"][0] == pytest.approx(19.770000000000003, rel=1e-9) assert row_10["y"][0] == pytest.approx(4.919999999999998, rel=1e-9) - assert row_10["s"][0] == pytest.approx(1.5, rel=1e-9) + assert row_10["v"][0] == pytest.approx(1.5, rel=1e-9) assert row_10["a"][0] == pytest.approx(2.13, rel=1e-9) assert row_10["dis"][0] == pytest.approx(0.19, rel=1e-9) assert row_10["o"][0] == pytest.approx(-1.3828243663551074, rel=1e-9) assert row_10["dir"][0] == pytest.approx(-2.176600110162128, rel=1e-9) assert row_10["event"][0] == None - assert row_10["officialPosition"][0] == "DE" - assert row_10["possessionTeam"][0] == "LV" + assert row_10["position_name"][0] == "DE" + assert row_10["ball_owning_team_id"][0] == "LV" assert row_10["graph_id"][0] == "2021091300-4845" assert "label" in data.columns diff --git a/unravel/american_football/dataset/dataset.py b/unravel/american_football/dataset/dataset.py index 30a038e9..12ecd870 100644 --- a/unravel/american_football/dataset/dataset.py +++ b/unravel/american_football/dataset/dataset.py @@ -87,8 +87,9 @@ def load(self): play_direction = "left" if "club" in df.collect_schema().names(): - df = df.with_columns(pl.col(Column.CLUB).alias(Column.TEAM)) - df = df.drop(Column.CLUB) + df = df.rename({"club": Column.TEAM_ID}) + elif "team" in df.collect_schema().names(): + df = df.rename({"team": Column.TEAM_ID}) if self._orient_ball_owning: df = ( @@ -130,10 +131,10 @@ def load(self): .otherwise(pl.col(Column.Y)) .alias(Column.Y), # set "football" to nflId -9999 for ordering purposes - pl.when(pl.col(Column.TEAM) == Constant.BALL) + pl.when(pl.col(Column.TEAM_ID) == Constant.BALL) .then(-9999.9) - .otherwise(pl.col(Column.OBJECT_ID)) - .alias(Column.OBJECT_ID), + .otherwise(pl.col("nflId")) + .alias("nflId"), ] 
) .with_columns( @@ -141,7 +142,7 @@ def load(self): pl.lit(play_direction).alias("playDirection"), ] ) - .filter((pl.col(Column.FRAME_ID) % sample) == 0) + .filter((pl.col("frameId") % sample) == 0) ).collect() else: raise NotImplementedError( @@ -157,15 +158,12 @@ def load(self): ignore_errors=True, ) if "position" in players.columns: - players = players.with_columns( - pl.col("position").alias(Column.OFFICIAL_POSITION) - ) - players = players.drop("position") + players = players.rename({"position": Column.POSITION_NAME}) + elif "officialPosition" in players.columns: + players = players.rename({"officialPosition": Column.POSITION_NAME}) players = players.with_columns( - pl.col(Column.OBJECT_ID) - .cast(pl.Float64, strict=False) - .alias(Column.OBJECT_ID) + pl.col("nflId").cast(pl.Float64, strict=False).alias("nflId") ) players = self._convert_weight_height_to_metric(df=players) @@ -175,27 +173,45 @@ def load(self): encoding="utf8", null_values=["NA", "NULL", ""], try_parse_dates=True, + ).rename( + { + "gameId": Column.GAME_ID, + "playId": Column.PLAY_ID, + "possessionTeam": Column.BALL_OWNING_TEAM_ID, + } ) df = df.join( ( players.select( [ - Column.OBJECT_ID, - Column.OFFICIAL_POSITION, + "nflId", + Column.POSITION_NAME, Column.HEIGHT_CM, Column.WEIGHT_KG, ] ) ), - on=Column.OBJECT_ID, + on="nflId", how="left", ) + + df = df.rename( + { + "nflId": Column.OBJECT_ID, + "gameId": Column.GAME_ID, + "frameId": Column.FRAME_ID, + "playId": Column.PLAY_ID, + "s": Column.SPEED, + } + ) + df = df.join( - (plays.select(Group.BY_PLAY_POSSESSION_TEAM)), + (plays.select(Group.BY_PLAY_BALL_OWNING)), on=[Column.GAME_ID, Column.PLAY_ID], how="left", ) + self.data = df # update pitch dimensions to how it looks after loading @@ -211,12 +227,14 @@ def load(self): return self.data, self.settings def add_dummy_labels( - self, by: List[str] = ["gameId", "playId", "frameId"] + self, by: List[str] = [Column.GAME_ID, Column.PLAY_ID, Column.FRAME_ID] ) -> pl.DataFrame: self.data = 
add_dummy_label_column(self.data, by, self._label_column) return self.data - def add_graph_ids(self, by: List[str] = ["gameId", "playId"]) -> pl.DataFrame: + def add_graph_ids( + self, by: List[str] = [Column.GAME_ID, Column.PLAY_ID] + ) -> pl.DataFrame: self.data = add_graph_id_column(self.data, by, self._graph_id_column) return self.data diff --git a/unravel/american_football/dataset/objects.py b/unravel/american_football/dataset/objects.py index 379b40e1..12b3f047 100644 --- a/unravel/american_football/dataset/objects.py +++ b/unravel/american_football/dataset/objects.py @@ -4,27 +4,30 @@ class Constant: class Column: - OBJECT_ID = "nflId" + OBJECT_ID = "id" - GAME_ID = "gameId" - FRAME_ID = "frameId" - PLAY_ID = "playId" + GAME_ID = "game_id" + FRAME_ID = "frame_id" + PLAY_ID = "play_id" X = "x" Y = "y" + SPEED = "v" + ACCELERATION = "a" - SPEED = "s" + + TEAM_ID = "team_id" + POSITION_NAME = "position_name" + + BALL_OWNING_TEAM_ID = "ball_owning_team_id" + ORIENTATION = "o" DIRECTION = "dir" - TEAM = "team" - CLUB = "club" - OFFICIAL_POSITION = "officialPosition" - POSSESSION_TEAM = "possessionTeam" HEIGHT_CM = "height_cm" WEIGHT_KG = "weight_kg" class Group: BY_FRAME = [Column.GAME_ID, Column.PLAY_ID, Column.FRAME_ID] - BY_PLAY_POSSESSION_TEAM = [Column.GAME_ID, Column.PLAY_ID, Column.POSSESSION_TEAM] + BY_PLAY_BALL_OWNING = [Column.GAME_ID, Column.PLAY_ID, Column.BALL_OWNING_TEAM_ID] diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index 327f94bc..e1e2ff3e 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -107,8 +107,8 @@ def __remove_with_missing_football(): .agg( [ pl.len().alias("size"), # Count total rows in each group - pl.col(Column.TEAM) - .filter(pl.col(Column.TEAM) == Constant.BALL) + pl.col(Column.TEAM_ID) + .filter(pl.col(Column.TEAM_ID) == Constant.BALL) .count() .alias("football_count"), # Count rows 
where team == 'football' ] @@ -179,9 +179,9 @@ def _exprs_variables(self): Column.ACCELERATION, Column.ORIENTATION, Column.DIRECTION, - Column.TEAM, - Column.OFFICIAL_POSITION, - Column.POSSESSION_TEAM, + Column.TEAM_ID, + Column.POSITION_NAME, + Column.BALL_OWNING_TEAM_ID, Column.HEIGHT_CM, Column.WEIGHT_KG, self.graph_id_column, @@ -228,8 +228,8 @@ def _compute(self, args: List[pl.Series]) -> dict: ) adjacency_matrix = compute_adjacency_matrix( - team=d[Column.TEAM], - possession_team=d[Column.POSSESSION_TEAM], + team=d[Column.TEAM_ID], + possession_team=d[Column.BALL_OWNING_TEAM_ID], settings=self.settings, ) edge_features = compute_edge_features( @@ -239,7 +239,7 @@ def _compute(self, args: List[pl.Series]) -> dict: a=d[Column.ACCELERATION], dir=d[Column.DIRECTION], o=d[Column.ORIENTATION], - team=d[Column.TEAM], + team=d[Column.TEAM_ID], settings=self.settings, ) node_features = compute_node_features( @@ -249,9 +249,9 @@ def _compute(self, args: List[pl.Series]) -> dict: a=d[Column.ACCELERATION], dir=d[Column.DIRECTION], o=d[Column.ORIENTATION], - team=d[Column.TEAM], - official_position=d[Column.OFFICIAL_POSITION], - possession_team=d[Column.POSSESSION_TEAM], + team=d[Column.TEAM_ID], + official_position=d[Column.POSITION_NAME], + possession_team=d[Column.BALL_OWNING_TEAM_ID], height=d[Column.HEIGHT_CM], weight=d[Column.WEIGHT_KG], graph_features=graph_features, diff --git a/unravel/soccer/dataset/objects.py b/unravel/soccer/dataset/objects.py index ae20acf7..1c7bbcc2 100644 --- a/unravel/soccer/dataset/objects.py +++ b/unravel/soccer/dataset/objects.py @@ -3,17 +3,10 @@ class Constant: class Column: - BALL_OWNING_TEAM_ID = "ball_owning_team_id" - BALL_OWNING_PLAYER_ID = "ball_owning_player_id" - IS_BALL_CARRIER = "is_ball_carrier" - PERIOD_ID = "period_id" - TIMESTAMP = "timestamp" - BALL_STATE = "ball_state" - FRAME_ID = "frame_id" - GAME_ID = "game_id" - TEAM_ID = "team_id" OBJECT_ID = "id" - POSITION_NAME = "position_name" + + GAME_ID = "game_id" + 
FRAME_ID = "frame_id" X = "x" Y = "y" @@ -29,6 +22,15 @@ class Column: AY = "ay" AZ = "az" + BALL_OWNING_TEAM_ID = "ball_owning_team_id" + BALL_OWNING_PLAYER_ID = "ball_owning_player_id" + IS_BALL_CARRIER = "is_ball_carrier" + PERIOD_ID = "period_id" + TIMESTAMP = "timestamp" + BALL_STATE = "ball_state" + TEAM_ID = "team_id" + POSITION_NAME = "position_name" + class Group: BY_FRAME = [Column.GAME_ID, Column.PERIOD_ID, Column.FRAME_ID] From aa98f7e97c4f84ca0b2566582b4107256e21fbef Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 11:28:17 +0200 Subject: [PATCH 6/6] fix --- tests/test_spektral.py | 4 ++-- unravel/utils/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_spektral.py b/tests/test_spektral.py index a4400b36..f417dfa8 100644 --- a/tests/test_spektral.py +++ b/tests/test_spektral.py @@ -48,8 +48,8 @@ def bdb_dataset(self, coordinates: str, players: str, plays: str): max_player_acceleration=10.0, max_ball_acceleration=10.0, ) - bdb_dataset.add_graph_ids(by=["gameId", "playId"]) - bdb_dataset.add_dummy_labels(by=["gameId", "playId", "frameId"]) + bdb_dataset.add_graph_ids(by=["game_id", "play_id"]) + bdb_dataset.add_dummy_labels(by=["game_id", "play_id", "frame_id"]) return bdb_dataset @pytest.fixture diff --git a/unravel/utils/utils.py b/unravel/utils/utils.py index 61c07a0e..bac1d4c9 100644 --- a/unravel/utils/utils.py +++ b/unravel/utils/utils.py @@ -72,7 +72,7 @@ def add_dummy_label_column( def add_graph_id_column( dataset: pl.DataFrame, - by: List[str] = ["gameId", "playId"], + by: List[str] = ["game_id", "play_id"], column_name: str = "graph_id", ): return dataset.with_columns([pl.concat_str(by, separator="-").alias(column_name)])