diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py index fe49461a..f355b413 100644 --- a/unravel/soccer/graphs/graph_converter_pl.py +++ b/unravel/soccer/graphs/graph_converter_pl.py @@ -153,6 +153,17 @@ def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]): "Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. " ) + @staticmethod + def _sort(df): + sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( + (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) + & (pl.col(Column.TEAM_ID) != Constant.BALL) + ).cast(int) + + df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]) + df = df.sort(Group.BY_FRAME + [Column.OBJECT_ID]) + return df + def _shuffle(self): if isinstance(self.settings.random_seed, int): self.dataset = self.dataset.sample( @@ -161,16 +172,7 @@ def _shuffle(self): elif self.settings.random_seed == True: self.dataset = self.dataset.sample(fraction=1.0) else: - - sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - ( - (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) - & (pl.col(Column.TEAM_ID) != Constant.BALL) - ).cast(int) - - self.dataset = self.dataset.sort( - [*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)] - ) - self.dataset = self.dataset.sort(Group.BY_FRAME + [Column.OBJECT_ID]) + self.dataset = self._sort(self.dataset) def _remove_incomplete_frames(self) -> pl.DataFrame: df = self.dataset @@ -707,6 +709,7 @@ def plot( team_color_a: str = "#CD0E61", team_color_b: str = "#0066CC", ball_color: str = "black", + sort: bool = True, color_by: Literal["ball_owning", "static_home_away"] = "ball_owning", ): """ @@ -1079,13 +1082,16 @@ def frame_plot(self, frame_data): self._fig = plt.figure(figsize=(25, 18)) self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05) + if sort: + df = self._sort(df) + if generate_video: writer = animation.FFMpegWriter(fps=fps, bitrate=1800) with writer.saving(self._fig, file_path, dpi=300): - for group_id, frame_data in df.sort( - Group.BY_FRAME + [Column.OBJECT_ID] - ).group_by(Group.BY_FRAME, maintain_order=True): + for group_id, frame_data in df.group_by( + Group.BY_FRAME, maintain_order=True + ): self._fig.clear() frame_plot(self, frame_data) writer.grab_frame()