Skip to content
Merged

sort #39

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions unravel/soccer/graphs/graph_converter_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ def _verify_feature_funcs(self, funcs, feature_type: Literal["edge", "node"]):
"Function has an incorrect feature type edge features should be 'edge', node features should be 'node'. "
)

@staticmethod
def _sort(df):
sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
& (pl.col(Column.TEAM_ID) != Constant.BALL)
).cast(int)

df = df.sort([*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)])
df = df.sort(Group.BY_FRAME + [Column.OBJECT_ID])
return df

def _shuffle(self):
if isinstance(self.settings.random_seed, int):
self.dataset = self.dataset.sample(
Expand All @@ -161,16 +172,7 @@ def _shuffle(self):
elif self.settings.random_seed == True:
self.dataset = self.dataset.sample(fraction=1.0)
else:

sort_expr = (pl.col(Column.TEAM_ID) == Constant.BALL).cast(int) * 2 - (
(pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID))
& (pl.col(Column.TEAM_ID) != Constant.BALL)
).cast(int)

self.dataset = self.dataset.sort(
[*Group.BY_FRAME, sort_expr, pl.col(Column.OBJECT_ID)]
)
self.dataset = self.dataset.sort(Group.BY_FRAME + [Column.OBJECT_ID])
self.dataset = self._sort(self.dataset)

def _remove_incomplete_frames(self) -> pl.DataFrame:
df = self.dataset
Expand Down Expand Up @@ -707,6 +709,7 @@ def plot(
team_color_a: str = "#CD0E61",
team_color_b: str = "#0066CC",
ball_color: str = "black",
sort: bool = True,
color_by: Literal["ball_owning", "static_home_away"] = "ball_owning",
):
"""
Expand Down Expand Up @@ -1079,13 +1082,16 @@ def frame_plot(self, frame_data):
self._fig = plt.figure(figsize=(25, 18))
self._fig.subplots_adjust(left=0.06, right=1.0, bottom=0.05)

if sort:
df = self._sort(df)

if generate_video:
writer = animation.FFMpegWriter(fps=fps, bitrate=1800)

with writer.saving(self._fig, file_path, dpi=300):
for group_id, frame_data in df.sort(
Group.BY_FRAME + [Column.OBJECT_ID]
).group_by(Group.BY_FRAME, maintain_order=True):
for group_id, frame_data in df.group_by(
Group.BY_FRAME, maintain_order=True
):
self._fig.clear()
frame_plot(self, frame_data)
writer.grab_frame()
Expand Down