From 2634a2e00acb1f9660673e52b5ca4cce685ad1f0 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Fri, 23 May 2025 15:10:45 +0200 Subject: [PATCH] add epsilon for nans --- unravel/soccer/graphs/graph_converter.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index 7f4abb07..5d006238 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -481,8 +481,18 @@ def __add_additional_kwargs(self, d): d["is_gk"] = np.where( d[Column.POSITION_NAME] == self.settings.goalkeeper_id, True, False ) - d["position"] = np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1) - d["velocity"] = np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1) + d["position"] = np.nan_to_num( + np.stack((d[Column.X], d[Column.Y], d[Column.Z]), axis=-1), + nan=1e-10, + posinf=1e3, + neginf=-1e3, + ) + d["velocity"] = np.nan_to_num( + np.stack((d[Column.VX], d[Column.VY], d[Column.VZ]), axis=-1), + nan=1e-10, + posinf=1e3, + neginf=-1e3, + ) if len(np.where(d["team_id"] == d["ball_id"])[0]) >= 1: ball_index = np.where(d["team_id"] == d["ball_id"])[0]