diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index 7f4abb07..5d006238 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -481,8 +481,18 @@ def __add_additional_kwargs(self, d): d["is_gk"] = np.where( d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False ) - d["position"] = np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1) - d["velocity"] = np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1) + d["position"] = np.nan_to_num( + np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1), + nan=1e-10, + posinf=1e3, + neginf=-1e3, + ) + d["velocity"] = np.nan_to_num( + np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1), + nan=1e-10, + posinf=1e3, + neginf=-1e3, + ) if len(np.where(d["team_id"] == d["ball_id"])[0]) >= 1: ball_index = np.where(d["team_id"] == d["ball_id"])[0]