From 99883aa4db1e90d0e052229d964234bb9d705ada Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Tue, 20 May 2025 15:23:31 +0200 Subject: [PATCH 1/2] sort --- unravel/soccer/graphs/graph_converter_pl.py | 32 ++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py index fe49461a..80186f14 100644 --- a/unravel/soccer/graphs/graph_converter_pl.py +++ b/unravel/soccer/graphs/graph_converter_pl.py @@ -153,6 +153,17 @@ def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. " ) + @staticmethod + def _sort(df): + sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( + (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) + & (pl.col(Column.TEAM_ID) != Constant.BALL) + ).cast(int) + + df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]) + df = df.sort(Group.BY_FRAME + [Column.OBJECT_ID]) + return df + def _shuffle(self): if isinstance(self.settings.random_seed, int): self.dataset = self.dataset.sample( @@ -161,16 +172,7 @@ def _shuffle(self): elif self.settings.random_seed == True: self.dataset = self.dataset.sample(fraction=1.0) else: - - sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( - (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) - & (pl.col(Column.TEAM_ID) != Constant.BALL) - ).cast(int) - - self.dataset = self.dataset.sort( - [*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)] - ) - self.dataset = self.dataset.sort(Group.BY_FRAME + [Column.OBJECT_ID]) + self.dataset = self._sort(self.dataset) def _remove_incomplete_frames(self) -> pl.DataFrame: df = self.dataset @@ -707,6 +709,7 @@ def plot( team_color_a: str = "#CD0E61", team_color_b: str = "#0066CC", ball_color: str = "black", + sort: bool = True, color_by: Literal["ball_owning", "static_home_away"] = "ball_owning", ): """ @@ -1082,10 +1085,13 @@ def frame_plot(self, frame_data): if generate_video: writer = animation.FFMpegWriter(fps=fps, bitrate=1800) + if sort: + df = self._sort(df) + with writer.saving(self._fig, file_path, dpi=300): - for group_id, frame_data in df.sort( - Group.BY_FRAME + [Column.OBJECT_ID] - ).group_by(Group.BY_FRAME, maintain_order=True): + for group_id, frame_data in df.group_by( + Group.BY_FRAME, maintain_order=True + ): self._fig.clear() frame_plot(self, frame_data) writer.grab_frame() From 5f3467a662ca4302ec5b39160266ca9847a75d21 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Tue, 20 May 2025 15:27:25 +0200 Subject: [PATCH 2/2] fix sort --- unravel/soccer/graphs/graph_converter_pl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py index 80186f14..f355b413 100644 --- a/unravel/soccer/graphs/graph_converter_pl.py +++ b/unravel/soccer/graphs/graph_converter_pl.py @@ -1082,12 +1082,12 @@ def frame_plot(self, frame_data): self._fig = plt.figure(figsize=(25, 18)) self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05) + if sort: + df = self._sort(df) + if generate_video: writer = animation.FFMpegWriter(fps=fps, bitrate=1800) - if sort: - df = self._sort(df) - with writer.saving(self._fig, file_path, dpi=300): for group_id, frame_data in df.group_by( Group.BY_FRAME, maintain_order=True