diff --git a/tests/files/plot/test-1.mp4 b/tests/files/plot/test-1.mp4 index 40f0ae3a..1b3b3af7 100644 Binary files a/tests/files/plot/test-1.mp4 and b/tests/files/plot/test-1.mp4 differ diff --git a/tests/files/plot/test-no-extension.png b/tests/files/plot/test-no-extension.png index c58de2f5..13c820fd 100644 Binary files a/tests/files/plot/test-no-extension.png and b/tests/files/plot/test-no-extension.png differ diff --git a/tests/files/plot/test-png.png b/tests/files/plot/test-png.png index c58de2f5..13c820fd 100644 Binary files a/tests/files/plot/test-png.png and b/tests/files/plot/test-png.png differ diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index c98c3634..04ad45db 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -650,6 +650,17 @@ def _convert(self): .drop("result_dict") ) + def get_players_by_team_id(self, team_id): + return [ + player for player in self.settings.players if player["team_id"] == team_id + ] + + def get_player_by_id(self, player_id): + for player in self.settings.players: + if player["player_id"] == player_id: + return player + return None + def plot( self, file_path: str, @@ -802,6 +813,15 @@ def plot( def plot_graph(): import matplotlib.pyplot as plt + labels = [ + ( + self.get_player_by_id(pid)["jersey_no"] + if pid != Constant.BALL + else Constant.BALL + ) + for pid in self._graph.object_ids + ] + # Plot node features in top-left ax1 = self._fig.add_subplot(self._gs[0, 0]) ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd") @@ -810,7 +830,7 @@ def plot_graph(): # Set y labels to integers num_rows = self._graph.x.shape[0] ax1.set_yticks(range(num_rows)) - ax1.set_yticklabels([str(i) for i in range(num_rows)]) + ax1.set_yticklabels(labels) node_feature_yticklabels = feature_ticklabels(self._node_feature_dims) ax1.xaxis.set_ticks_position("top") @@ -827,10 +847,10 @@ def plot_graph(): num_cols_a = self._graph.a.toarray().shape[1] ax2.set_yticks(range(num_rows_a)) - ax2.set_yticklabels([str(i) for i in range(num_rows_a)]) + ax2.set_yticklabels(labels) ax2.xaxis.set_ticks_position("top") ax2.set_xticks(range(num_cols_a)) - ax2.set_xticklabels([str(i) for i in range(num_cols_a)]) + ax2.set_xticklabels(labels) # Plot Edge Features on the right (spanning both rows) ax3 = self._fig.add_subplot(self._gs[:, 1]) @@ -850,18 +870,19 @@ def plot_graph(): ax3.set_yticks(range(num_rows_e)) ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18) + ball_carrier_edge_idxs = list(ball_carrier_edge_idx[0]) ax3.set_xlabel(f"Edge Features {self._graph.e.shape}") - labels = ax3.get_yticklabels() + ax3_labels = ax3.get_yticklabels() if self._ball_carrier_idx in ball_carrier_edge_idx[0]: idx_position = list(ball_carrier_edge_idx[0]).index( self._ball_carrier_idx ) # Modify just that specific label - labels[idx_position].set_color(self._ball_carrier_color) - labels[idx_position].set_fontweight("bold") + ax3_labels[idx_position].set_color(self._ball_carrier_color) + ax3_labels[idx_position].set_fontweight("bold") # Set the modified labels back - ax3.set_yticklabels(labels) + ax3.set_yticklabels([labels[i] for i in ball_carrier_edge_idxs]) # Set x labels to edge function names at the top, rotated 45 degrees edge_feature_xticklabels = feature_ticklabels(self._edge_feature_dims) @@ -980,7 +1001,11 @@ def player_and_ball(frame_data, ax): text = ax.text( x + (-1.2 if is_ball else 0.0), y + (-1.2 if is_ball else 0.0), - i, + ( + self.get_player_by_id(r[Column.OBJECT_ID])["jersey_no"] + if r[Column.OBJECT_ID] != Constant.BALL + else Constant.BALL + ), color=self._ball_color if is_ball else color, fontsize=12, ha="center", @@ -1035,13 +1060,15 @@ def timestamp_to_gameclock(timestamp, period_id): features["e"], features["e_shape_0"], features["e_shape_1"] ) y = np.asarray([features[self.label_column]]) - frame_id = features["frame_id"] self._graph = Graph( a=a, x=x, e=e, y=y, + frame_id=features["frame_id"], + object_ids=frame_data[Column.OBJECT_ID], + ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0], ) self._ball_carrier_idx = np.where(