diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index b6599650..1b4f0ad4 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -12,14 +12,67 @@ jobs: test: strategy: matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.11'] + include: + # Python 3.11 with Spektral/TensorFlow + - os: ubuntu-latest + python-version: '3.11' + test-deps: 'test-py311' + pytest-markers: '' + - os: windows-latest + python-version: '3.11' + test-deps: 'test-py311' + pytest-markers: '' + - os: macos-latest + python-version: '3.11' + test-deps: 'test-py311' + pytest-markers: '' + + # Python 3.12 with PyTorch only + - os: ubuntu-latest + python-version: '3.12' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' + - os: windows-latest + python-version: '3.12' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' + - os: macos-latest + python-version: '3.12' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' + + # Python 3.13 with PyTorch only + - os: ubuntu-latest + python-version: '3.13' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' + - os: windows-latest + python-version: '3.13' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' + - os: macos-latest + python-version: '3.13' + test-deps: 'test-torch' + pytest-markers: '-m "not spektral"' runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2 + - name: Free up disk space (Linux only) + if: runner.os == 'Linux' + run: | + echo "Disk space before cleanup:" + df -h + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /usr/local/share/boost + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + sudo docker image prune --all --force + echo "Disk space after cleanup:" + df -h + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: @@ -27,46 +80,24 @@ jobs: # Linux/macOS - Install dependencies - name: Install dependencies (Linux/macOS) - if: runner.os != 'Windows' # Only run on Linux/macOS + if: runner.os != 'Windows' env: - PYTHONIOENCODING: utf-8 # Ensure Python uses UTF-8 encoding + PYTHONIOENCODING: utf-8 run: | python -m pip install --upgrade pip - python -m pip install -e .[test] - shell: bash - - # Install FFmpeg on Ubuntu - - name: Install FFmpeg (Ubuntu) - if: runner.os == 'Linux' - run: | - sudo apt-get update - sudo apt-get install -y ffmpeg + python -m pip install -e .[${{ matrix.test-deps }}] shell: bash - # Install FFmpeg on macOS - - name: Install FFmpeg (macOS) - if: runner.os == 'macOS' - run: | - brew install ffmpeg - shell: bash - - # Install FFmpeg on Windows - - name: Install FFmpeg (Windows) - if: runner.os == 'Windows' - run: | - choco install ffmpeg -y - shell: pwsh - # Windows - Install dependencies - name: Install dependencies (Windows) - if: runner.os == 'Windows' # Only run on Windows + if: runner.os == 'Windows' env: - PYTHONIOENCODING: utf-8 # Ensure Python uses UTF-8 encoding - PYTHONUTF8: 1 # Force Python to use UTF-8 mode + PYTHONIOENCODING: utf-8 + PYTHONUTF8: 1 run: | - chcp 65001 # Change code page to UTF-8 + chcp 65001 python -m pip install --upgrade pip setuptools - python -m pip install -e .[test] + python -m pip install -e .[${{ matrix.test-deps }}] shell: pwsh - name: Code formatting @@ -76,4 +107,4 @@ jobs: - name: Test with pytest run: | - pytest --color=yes + pytest --color=yes ${{ matrix.pytest-markers }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 5b68b22f..577c8727 100644 --- a/.gitignore +++ b/.gitignore @@ -195,4 +195,6 @@ examples/models/* *.mp4 *.png *.json -diffs/ \ No newline at end of file +diffs/ + +lightning_logs/ \ No newline at end of file diff --git a/README.md b/README.md index 20c11e28..3aabe270 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ bdb = BigDataBowlDataset( ### **Graph Neural Networks** -⚽🏈 Convert **[Polars Dataframes](#polars-dataframes)** into [Graphs](examples/graphs_faq.md) to train **graph neural networks**. These [Graphs](examples/graphs_faq.md) can be used with [**Spektral**](https://github.com/danielegrattarola/spektral) - a flexible framework for training graph neural networks. +⚽🏈 Convert **[Polars Dataframes](#polars-dataframes)** into [Graphs](examples/graphs_faq.md) to train **graph neural networks**. These [Graphs](examples/graphs_faq.md) can be used with [**PyTorch Geometric**](https://pytorch-geometric.readthedocs.io/en/latest/) or [**Spektral**](https://github.com/danielegrattarola/spektral). `unravelsports` allows you to **randomize** and **split** data into train, test and validation sets along matches, sequences or possessions to avoid leakage and improve model quality. And finally, **train**, **validate** and **test** your (custom) Graph model(s) and easily **predict** on new data. ```python @@ -142,9 +142,9 @@ model.fit( πŸŒ€ Quick Start ----- -πŸ“– ⚽ The [**Quick Start Jupyter Notebook**](examples/0_quick_start_guide.ipynb) explains how to convert any positional tracking data from **Kloppy** to **Spektral GNN** in a few easy steps while walking you through the most important features and documentation. +πŸ“– ⚽ The [**Quick Start Jupyter Notebook**](examples/0_quick_start_guide_pyg.ipynb) explains how to convert any positional tracking data from **Kloppy** to **Spektral GNN** in a few easy steps while walking you through the most important features and documentation. -πŸ“– ⚽ The [**Graph Converter Tutorial Jupyter Notebook**](examples/1_kloppy_gnn_train.ipynb) gives an in-depth walkthrough. +πŸ“– ⚽ The [**Graph Converter Tutorial Jupyter Notebook**](examples/1_kloppy_gnn_train_pyg.ipynb) gives an in-depth walkthrough. πŸ“– 🏈 The [**BigDataBowl Converter Tutorial Jupyter Notebook**](examples/2_big_data_bowl_guide.ipynb) gives an guide on how to convert the BigDataBowl data into Graphs. @@ -169,18 +169,6 @@ The easiest way to get started is: pip install unravelsports ``` -⚠️ Due to compatibility issues **unravelsports** currently only works on Python 3.11 with: -``` -spektral==1.20.0 -tensorflow==2.14.0 -keras==2.14.0 -kloppy==3.17.0 -polars==1.2.1 -``` -These dependencies come pre-installed with the package. It is advised to create a [virtual environment](https://virtualenv.pypa.io/en/latest/). - -This package is tested on the latest versions of Ubuntu, MacOS and Windows. - πŸŒ€ Licenses ---- This project is licensed under the [Mozilla Public License Version 2.0 (MPL)](LICENSE), which requires that you include a copy of the license and provide attribution to the original authors. Any modifications you make to the MPL-licensed files must be documented, and the source code for those modifications must be made open-source under the same license. @@ -196,7 +184,7 @@ If you use this repository for any educational purposes, research, project etc., @software{unravelsports2024repository, author = {Bekkers, Joris}, title = {unravelsports}, - version = {0.3.0}, + version = {2.0.0}, year = {2024}, publisher = {GitHub}, url = {https://github.com/unravelsports/unravelsports} diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..16c57da6 --- /dev/null +++ b/conftest.py @@ -0,0 +1,49 @@ +import pytest +import sys +import os + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "spektral: tests that require spektral (Python 3.11 only)" + ) + config.addinivalue_line("markers", "torch: tests that require PyTorch") + config.addinivalue_line( + "markers", "local_only: tests that should only run in local environment" + ) + + +def pytest_collection_modifyitems(config, items): + """Automatically skip tests based on available dependencies and environment""" + + try: + import spektral + + has_spektral = True + except ImportError: + has_spektral = False + + try: + import torch + import torch_geometric + + has_torch = True + except ImportError: + has_torch = False + + # Check if running in CI or non-local environment + is_ci = os.getenv("CI") is not None + + skip_spektral = pytest.mark.skip(reason="Spektral not installed") + skip_torch = pytest.mark.skip(reason="PyTorch/PyG not installed") + skip_local = pytest.mark.skip( + reason="Skipping local-only tests in CI/non-local environment" + ) + + for item in items: + if "spektral" in item.keywords and not has_spektral: + item.add_marker(skip_spektral) + if "torch" in item.keywords and not has_torch: + item.add_marker(skip_torch) + if "local_only" in item.keywords and is_ci: + item.add_marker(skip_local) diff --git a/examples/0_quick_start_guide.ipynb b/examples/0_quick_start_guide.ipynb index 6205bd83..2b59314d 100644 --- a/examples/0_quick_start_guide.ipynb +++ b/examples/0_quick_start_guide.ipynb @@ -6,6 +6,8 @@ "source": [ "## πŸŒ€ Quick Start Guide: It's all starting to unravel!\n", "\n", + "⚠️ It is recommended to use the [PyTorch implementation](0_quick_start_guide_pyg.ipynb) over this Spektral version.\n", + "\n", "In this example we'll run through all the basic features the `unravelsports` package offers for converting a `kloppy` dataset of soccer tracking data into graphs for training binary classification graph neural networks using the `spektral` library.\n", "\n", "This guide will go through the following steps:\n", @@ -104,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +131,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -166,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -180,7 +182,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/jbekkers/PycharmProjects/unravelsports/.venv311/lib/python3.11/site-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311-test/lib/python3.11/site-packages/keras/src/initializers/initializers.py:120: UserWarning: The initializer GlorotUniform is unseeded and being called multiple times, which will return identical values each time (even if the initializer is unseeded). Please update your code to provide a seed to the initializer, or avoid using the same initializer instance more than once.\n", " warnings.warn(\n" ] }, @@ -188,9 +190,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "11/11 [==============================] - 1s 16ms/step - loss: 21.7806 - auc: 0.5278 - binary_accuracy: 0.5419 - val_loss: 5.1682 - val_auc: 0.5000 - val_binary_accuracy: 0.5000\n", + "11/11 [==============================] - 1s 16ms/step - loss: 82.8534 - auc: 0.5290 - binary_accuracy: 0.5375 - val_loss: 5.5782 - val_auc: 0.5308 - val_binary_accuracy: 0.5595\n", "Epoch 2/10\n", - " 1/11 [=>............................] - ETA: 0s - loss: 9.2846 - auc: 0.3651 - binary_accuracy: 0.5000WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 3 batches). You may need to use the repeat() function when building your dataset.\n" + " 1/11 [=>............................] - ETA: 0s - loss: 70.9456 - auc: 0.3438 - binary_accuracy: 0.3438WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 3 batches). You may need to use the repeat() function when building your dataset.\n" ] }, { @@ -204,32 +206,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "11/11 [==============================] - 0s 6ms/step - loss: 4.5155 - auc: 0.5366 - binary_accuracy: 0.5449\n", + "11/11 [==============================] - 0s 6ms/step - loss: 33.6154 - auc: 0.4974 - binary_accuracy: 0.4835\n", "Epoch 3/10\n", - "11/11 [==============================] - 0s 4ms/step - loss: 2.0773 - auc: 0.4515 - binary_accuracy: 0.4731\n", + "11/11 [==============================] - 0s 6ms/step - loss: 14.5440 - auc: 0.4738 - binary_accuracy: 0.4895\n", "Epoch 4/10\n", - "11/11 [==============================] - 0s 5ms/step - loss: 1.1006 - auc: 0.5205 - binary_accuracy: 0.5150\n", + "11/11 [==============================] - 0s 5ms/step - loss: 7.6316 - auc: 0.4957 - binary_accuracy: 0.4985\n", "Epoch 5/10\n", - "11/11 [==============================] - 0s 4ms/step - loss: 0.9159 - auc: 0.4915 - binary_accuracy: 0.5180\n", + "11/11 [==============================] - 0s 5ms/step - loss: 4.4984 - auc: 0.5123 - binary_accuracy: 0.4985\n", "Epoch 6/10\n", - "11/11 [==============================] - 0s 5ms/step - loss: 0.8020 - auc: 0.4873 - binary_accuracy: 0.5060\n", + "11/11 [==============================] - 0s 5ms/step - loss: 3.3299 - auc: 0.5680 - binary_accuracy: 0.5495\n", "Epoch 7/10\n", - "11/11 [==============================] - 0s 4ms/step - loss: 0.8067 - auc: 0.4960 - binary_accuracy: 0.5299\n", + "11/11 [==============================] - 0s 5ms/step - loss: 2.9137 - auc: 0.4771 - binary_accuracy: 0.4775\n", "Epoch 8/10\n", - "11/11 [==============================] - 0s 6ms/step - loss: 0.7808 - auc: 0.5055 - binary_accuracy: 0.5299\n", + "11/11 [==============================] - 0s 5ms/step - loss: 2.1859 - auc: 0.5167 - binary_accuracy: 0.5075\n", "Epoch 9/10\n", - "11/11 [==============================] - 0s 4ms/step - loss: 0.7661 - auc: 0.4937 - binary_accuracy: 0.5060\n", + "11/11 [==============================] - 0s 4ms/step - loss: 2.2885 - auc: 0.4226 - binary_accuracy: 0.4474\n", "Epoch 10/10\n", - "11/11 [==============================] - 0s 5ms/step - loss: 0.7406 - auc: 0.5098 - binary_accuracy: 0.5329\n" + "11/11 [==============================] - 0s 4ms/step - loss: 1.5413 - auc: 0.5234 - binary_accuracy: 0.5195\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -267,14 +269,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "3/3 [==============================] - 0s 6ms/step - loss: 0.7001 - auc: 0.5000 - binary_accuracy: 0.4819\n" + "3/3 [==============================] - 0s 6ms/step - loss: 0.8078 - auc: 0.3817 - binary_accuracy: 0.4819\n" ] } ], @@ -296,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -315,7 +317,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv311", + "display_name": ".venv313-test", "language": "python", "name": "python3" }, @@ -329,7 +331,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.13.2" } }, "nbformat": 4, diff --git a/examples/0_quick_start_guide_pyg.ipynb b/examples/0_quick_start_guide_pyg.ipynb new file mode 100644 index 00000000..f05d16ec --- /dev/null +++ b/examples/0_quick_start_guide_pyg.ipynb @@ -0,0 +1,475 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## πŸŒ€ Quick Start Guide: It's all starting to unravel!\n", + "\n", + "### ‼️ NEW PYTORCH VERSION\n", + "\n", + "In this example we'll run through all the basic features the `unravelsports` package offers for converting a `kloppy` dataset of soccer tracking data into graphs for training binary classification graph neural networks using PyTorch Geometric and PyTorch Lightning.\n", + "\n", + "This guide will go through the following steps:\n", + "\n", + "- [**1. Process Data**](#1-processing-data). We'll show how to load a `kloppy` dataset and convert each individual frame into a single graph. All necessary steps (like setting the correct coordinate system, and left-right normalization) are done under the hood of the converter.\n", + "- [**1.1 Split Data**](#11-split-data).\n", + "- [**2. Initialize Model**](#2-initialize-model). We initialize the built-in binary classification model as presented in [A Graph Neural Network Deep-dive into Successful Counterattacks {A. Sahasrabudhe & J. Bekkers}](https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn).\n", + "- [**3. Train Model**](#3-train-model). Using the initialized model we train it on the training set created in step [1.1 Splitting Data](#11-split-data).\n", + "- [**4. Evaluate Model Performance**](#4-evaluate-model-performance). We calculate model performance using the metrics defined in the model.\n", + "- [**5. Predict**](#5-predict). Finally, we apply the trained model to unseen data.\n", + "- [**6. Save & Load Model**](#6-save--load-model). Learn how to save and reload your trained models.\n", + "\n", + "
\n", + "Before we get started it is important to note that the unravelsports library does not have built in functionality to create binary labels, these will need to be supplied by the reader. In this example we use the dummy_labels() functionality that comes with the package. This function creates a single binary label for each frame by randomly assigning it a 0 or 1 value.\n", + "\n", + "When supplying your own labels they need to be in the form of a dictionary (more information on this can be found in the [in-depth Walkthrough](1_kloppy_gnn_train.ipynb)) \n", + "\n", + "\n", + "\n", + "-----\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The first thing is to run `pip install unravelsports` if you haven't already!\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install unravelsports torch torch-geometric pytorch-lightning torchmetrics --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Process Data\n", + "\n", + "1. Load [Kloppy](https://github.com/PySport/kloppy) dataset. \n", + " See [in-depth Tutorial](1_kloppy_gnn_train.ipynb) on how to process multiple match files, and to see an overview of all possible settings.\n", + "2. Convert to Graph format using `SoccerGraphConverter`\n", + "3. Create dataset for easy processing with PyTorch Geometric" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311-test/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from unravel.soccer import SoccerGraphConverter, KloppyPolarsDataset\n", + "from unravel.utils import GraphDataset\n", + "\n", + "from kloppy import sportec\n", + "\n", + "# Load Kloppy dataset\n", + "kloppy_dataset = sportec.load_open_tracking_data(only_alive=True, limit=500)\n", + "kloppy_polars_dataset = KloppyPolarsDataset(\n", + " kloppy_dataset=kloppy_dataset,\n", + ")\n", + "kloppy_polars_dataset.add_dummy_labels()\n", + "kloppy_polars_dataset.add_graph_ids(by=[\"frame_id\"])\n", + "\n", + "# Initialize the Graph Converter with dataset\n", + "# Here we use the default settings\n", + "converter = SoccerGraphConverter(dataset=kloppy_polars_dataset)\n", + "\n", + "# Compute the graphs and add them to the GraphDataset\n", + "pyg_graphs = converter.to_pytorch_graphs()\n", + "dataset = GraphDataset(graphs=pyg_graphs, format=\"pyg\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.1 Split Data\n", + "\n", + "Split the dataset with the built in `split_test_train_validation` method." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train, test, val = dataset.split_test_train_validation(\n", + " split_train=4, split_test=1, split_validation=1, random_seed=43\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Initialize Model\n", + "\n", + "1. Initialize the `PyGLightningCrystalGraphClassifier` with PyTorch Lightning.\n", + "2. Set up callbacks for model checkpointing and early stopping.\n", + "3. Initialize the trainer.\n", + "\n", + "Note: The model settings are chosen to reflect the model used in [A Graph Neural Network Deep-dive into Successful Counterattacks {A. Sahasrabudhe & J. Bekkers}](https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n" + ] + } + ], + "source": [ + "from unravel.classifiers import PyGLightningCrystalGraphClassifier\n", + "import pytorch_lightning as pyl\n", + "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n", + "\n", + "# Initialize the Lightning model\n", + "lit_model = PyGLightningCrystalGraphClassifier(\n", + " n_layers=3, channels=128, drop_out=0.5, n_out=1\n", + ")\n", + "\n", + "# Set up callbacks\n", + "checkpoint_callback = ModelCheckpoint(\n", + " dirpath=\"models/\",\n", + " filename=\"best-model-{epoch:02d}-{val_auc:.2f}\",\n", + " save_top_k=1,\n", + " monitor=\"val_auc\",\n", + " mode=\"max\",\n", + ")\n", + "\n", + "early_stop_callback = EarlyStopping(monitor=\"val_loss\", patience=5, mode=\"min\")\n", + "\n", + "# Initialize trainer\n", + "trainer = pyl.Trainer(\n", + " max_epochs=10,\n", + " accelerator=\"auto\", # Automatically uses GPU if available\n", + " callbacks=[checkpoint_callback, early_stop_callback],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Train Model\n", + "\n", + "1. Create PyTorch Geometric `DataLoader` for training and validation sets.\n", + "2. Train the model using PyTorch Lightning's `trainer.fit()`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params | Mode | FLOPs\n", + "------------------------------------------------------------------------\n", + "0 | model | PyGCrystalGraphClassifier | 328 K | train | 0 \n", + "1 | criterion | BCELoss | 0 | train | 0 \n", + "2 | train_auc | BinaryAUROC | 0 | train | 0 \n", + "3 | train_acc | BinaryAccuracy | 0 | train | 0 \n", + "4 | val_auc | BinaryAUROC | 0 | train | 0 \n", + "5 | val_acc | BinaryAccuracy | 0 | train | 0 \n", + "6 | test_auc | BinaryAUROC | 0 | train | 0 \n", + "7 | test_acc | BinaryAccuracy | 0 | train | 0 \n", + "------------------------------------------------------------------------\n", + "328 K Trainable params\n", + "0 Non-trainable params\n", + "328 K Total params\n", + "1.315 Total estimated model params size (MB)\n", + "27 Modules in train mode\n", + "0 Modules in eval mode\n", + "0 Total Flops\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 8: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 11/11 [00:00<00:00, 42.73it/s, v_num=1, val_loss=0.695, val_auc=0.446, val_acc=0.429, train_loss=0.702, train_auc=0.503, train_acc=0.492]\n" + ] + } + ], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "\n", + "batch_size = 32\n", + "\n", + "# Create data loaders\n", + "loader_tr = DataLoader(train, batch_size=batch_size, shuffle=True)\n", + "loader_va = DataLoader(val, batch_size=batch_size, shuffle=False)\n", + "\n", + "# Train the model\n", + "trainer.fit(lit_model, loader_tr, loader_va)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Evaluate Model Performance\n", + "\n", + "1. Create a PyTorch Geometric `DataLoader` for the test set.\n", + "2. Use `trainer.test()` to evaluate. This automatically uses the metrics defined in the Lightning module.\n", + "\n", + "Note: Our performance is really bad because we're using random labels, very few epochs and a small dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311-test/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing DataLoader 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3/3 [00:00<00:00, 10.24it/s]\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", + " Test metric DataLoader 0\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", + " test_acc 0.5301204919815063\n", + " test_auc 0.5209789872169495\n", + " test_loss 0.6924476623535156\n", + "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", + "[{'test_loss': 0.6924476623535156, 'test_auc': 0.5209789872169495, 'test_acc': 0.5301204919815063}]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311-test/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 437. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n" + ] + } + ], + "source": [ + "loader_te = DataLoader(test, batch_size=batch_size, shuffle=False)\n", + "\n", + "# Test and get metrics\n", + "test_results = trainer.test(lit_model, loader_te)\n", + "print(test_results)\n", + "# Output: [{'test_loss': 0.234, 'test_auc': 0.85, 'test_acc': 0.78}]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Predict\n", + "\n", + "1. Use unseen data to predict on. In this example we're using the test dataset.\n", + "2. We have to re-create `loader_te` because the previous one was consumed.\n", + "3. Predictions come as a list of tensors (one per batch), so we concatenate them." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicting DataLoader 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3/3 [00:00<00:00, 73.37it/s]\n", + "Predictions shape: (83,)\n", + "First 10 predictions: [0.49448484 0.49671462 0.49610677 0.4947461 0.49599648 0.4933405\n", + " 0.4959804 0.495771 0.49324304 0.49329212]\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "loader_te = DataLoader(\n", + " test, batch_size=batch_size, shuffle=False, num_workers=15, persistent_workers=True\n", + ")\n", + "\n", + "# Get predictions\n", + "predictions = trainer.predict(lit_model, loader_te)\n", + "\n", + "# predictions is a list of tensors (one per batch)\n", + "# Concatenate to get all predictions\n", + "all_predictions = torch.cat(predictions).cpu().numpy()\n", + "\n", + "print(f\"Predictions shape: {all_predictions.shape}\")\n", + "print(f\"First 10 predictions: {all_predictions[:10]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Save & Load Model\n", + "\n", + "PyTorch Lightning offers several ways to save and load models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Saving the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`weights_only` was not set, defaulting to `False`.\n" + ] + } + ], + "source": [ + "# Method 1: Using ModelCheckpoint callback (already done during training)\n", + "# The best model is automatically saved\n", + "\n", + "# Method 2: Manual save\n", + "model_path = \"models/my-graph-classifier.ckpt\"\n", + "trainer.save_checkpoint(model_path)\n", + "\n", + "# Method 3: Save just the model weights (not trainer state)\n", + "torch.save(lit_model.state_dict(), \"models/my-model-weights.pth\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Loading a Saved Model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "πŸ’‘ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicting DataLoader 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 3/3 [00:00<00:00, 27.66it/s]\n", + "Loaded model predictions shape: (83,)\n" + ] + } + ], + "source": [ + "# Method 1: Load from Lightning checkpoint (Recommended)\n", + "loaded_model = PyGLightningCrystalGraphClassifier.load_from_checkpoint(\n", + " \"models/my-graph-classifier.ckpt\"\n", + ")\n", + "\n", + "# Create new trainer for loaded model\n", + "new_trainer = pyl.Trainer(accelerator=\"auto\")\n", + "\n", + "# Make predictions\n", + "loader_te = DataLoader(\n", + " test, batch_size=32, shuffle=False, num_workers=15, persistent_workers=True\n", + ")\n", + "predictions = new_trainer.predict(loaded_model, loader_te)\n", + "all_predictions = torch.cat(predictions).cpu().numpy()\n", + "\n", + "print(f\"Loaded model predictions shape: {all_predictions.shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# Method 2: Load just weights (requires you to create model first)\n", + "loaded_model = PyGLightningCrystalGraphClassifier(\n", + " n_layers=3, channels=128, drop_out=0.5, n_out=1\n", + ")\n", + "loaded_model.load_state_dict(torch.load(\"models/my-model-weights.pth\"))\n", + "loaded_model.eval()\n", + "\n", + "# Initialize lazy layers before using\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "loaded_model.to(device)\n", + "sample_batch = next(iter(loader_te))\n", + "sample_batch = sample_batch.to(device)\n", + "with torch.no_grad():\n", + " _ = loaded_model(\n", + " sample_batch.x,\n", + " sample_batch.edge_index,\n", + " sample_batch.edge_attr,\n", + " sample_batch.batch,\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv313-test", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/1_kloppy_gnn_train.ipynb b/examples/1_kloppy_gnn_train.ipynb index 83629a09..d2676171 100644 --- a/examples/1_kloppy_gnn_train.ipynb +++ b/examples/1_kloppy_gnn_train.ipynb @@ -6,6 +6,9 @@ "source": [ "## πŸŒ€ unravel kloppy into graph neural network using the _new_ Polars back-end!\n", "\n", + "⚠️ It is recommended to use the [PyTorch implementation](1_kloppy_gnn_train_pyg.ipynb) over this Spektral version.\n", + "\n", + "\n", "First run `pip install unravelsports` if you haven't already!\n", "\n", "\n", @@ -22,7 +25,7 @@ "output_type": "stream", "text": [ "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.3\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] @@ -86,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -330,13 +333,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "CustomSpektralDataset(n_graphs=2004)" + "GraphDataset(n_graphs=2004)" ] }, "execution_count": 4, @@ -347,7 +350,7 @@ "source": [ "from unravel.utils import GraphDataset\n", "\n", - "dataset = GraphDataset(pickle_folder=pickle_folder)\n", + "dataset = GraphDataset(pickle_folder=pickle_folder, format=\"spektral\")\n", "dataset" ] }, @@ -385,9 +388,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "Train: CustomSpektralDataset(n_graphs=1002)\n", - "Test: CustomSpektralDataset(n_graphs=501)\n", - "Validation: CustomSpektralDataset(n_graphs=501)\n" + "Train: GraphDataset(n_graphs=1002)\n", + "Test: GraphDataset(n_graphs=501)\n", + "Validation: GraphDataset(n_graphs=501)\n" ] } ], @@ -602,9 +605,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "32/32 [==============================] - 1s 9ms/step - loss: 51.8502 - auc: 0.5028 - binary_accuracy: 0.5070 - val_loss: 2.7810 - val_auc: 0.5000 - val_binary_accuracy: 0.4890\n", + "32/32 [==============================] - 1s 9ms/step - loss: 57.0675 - auc: 0.4954 - binary_accuracy: 0.4940 - val_loss: 4.5185 - val_auc: 0.5000 - val_binary_accuracy: 0.4890\n", "Epoch 2/5\n", - "24/32 [=====================>........] - ETA: 0s - loss: 3.4147 - auc: 0.4748 - binary_accuracy: 0.4844WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 16 batches). You may need to use the repeat() function when building your dataset.\n" + "27/32 [========================>.....] - ETA: 0s - loss: 5.9926 - auc: 0.5235 - binary_accuracy: 0.5231WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 16 batches). You may need to use the repeat() function when building your dataset.\n" ] }, { @@ -618,19 +621,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "32/32 [==============================] - 0s 5ms/step - loss: 2.9913 - auc: 0.4802 - binary_accuracy: 0.4880\n", + "32/32 [==============================] - 0s 5ms/step - loss: 5.5469 - auc: 0.5267 - binary_accuracy: 0.5250\n", "Epoch 3/5\n", - "32/32 [==============================] - 0s 5ms/step - loss: 1.2031 - auc: 0.4915 - binary_accuracy: 0.4930\n", + "32/32 [==============================] - 0s 4ms/step - loss: 1.9054 - auc: 0.5131 - binary_accuracy: 0.5130\n", "Epoch 4/5\n", - "32/32 [==============================] - 0s 4ms/step - loss: 0.9734 - auc: 0.5014 - binary_accuracy: 0.4950\n", + "32/32 [==============================] - 0s 4ms/step - loss: 1.3611 - auc: 0.4832 - binary_accuracy: 0.4870\n", "Epoch 5/5\n", - "32/32 [==============================] - 0s 4ms/step - loss: 0.8367 - auc: 0.4898 - binary_accuracy: 0.4750\n" + "32/32 [==============================] - 0s 5ms/step - loss: 1.1222 - auc: 0.4949 - binary_accuracy: 0.4890\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 10, @@ -713,7 +716,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "16/16 [==============================] - 0s 3ms/step - loss: 0.7829 - auc: 0.5000 - binary_accuracy: 0.4830\n" + "16/16 [==============================] - 0s 3ms/step - loss: 0.8442 - auc: 0.5000 - binary_accuracy: 0.5170\n" ] } ], @@ -736,7 +739,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -775,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -822,7 +825,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 4)
frame_idperiod_idtimestampy_hat
i64i64duration[ΞΌs]f32
10300112s0.991924
10301112s 40ms0.991924
10302112s 80ms0.991924
10303112s 120ms0.991924
10304112s 160ms0.991924
" + "shape: (5, 4)
frame_idperiod_idtimestampy_hat
i64i64duration[ΞΌs]f32
100131520ms0.525933
100131520ms0.525933
100131520ms0.525933
100131520ms0.525933
100131520ms0.525933
" ], "text/plain": [ "shape: (5, 4)\n", @@ -831,11 +834,11 @@ "β”‚ --- ┆ --- ┆ --- ┆ --- β”‚\n", "β”‚ i64 ┆ i64 ┆ duration[ΞΌs] ┆ f32 β”‚\n", "β•žβ•β•β•β•β•β•β•β•β•β•β•ͺ═══════════β•ͺ══════════════β•ͺ══════════║\n", - "β”‚ 10300 ┆ 1 ┆ 12s ┆ 0.991924 β”‚\n", - "β”‚ 10301 ┆ 1 ┆ 12s 40ms ┆ 0.991924 β”‚\n", - "β”‚ 10302 ┆ 1 ┆ 12s 80ms ┆ 0.991924 β”‚\n", - "β”‚ 10303 ┆ 1 ┆ 12s 120ms ┆ 0.991924 β”‚\n", - "β”‚ 10304 ┆ 1 ┆ 12s 160ms ┆ 0.991924 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.525933 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.525933 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.525933 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.525933 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.525933 β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, diff --git a/examples/1_kloppy_gnn_train_pyg.ipynb b/examples/1_kloppy_gnn_train_pyg.ipynb new file mode 100644 index 00000000..312ee1b3 --- /dev/null +++ b/examples/1_kloppy_gnn_train_pyg.ipynb @@ -0,0 +1,1227 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## πŸŒ€ unravel kloppy into graph neural network using the _new_ Polars back-end!\n", + "\n", + "### ‼️ NEW PYTORCH VERSION\n", + "\n", + "First run `pip install unravelsports` if you haven't already!\n", + "\n", + "\n", + "-----\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# %pip install unravelsports torch torch-geometric pytorch-lightning --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this in-depth walkthrough we'll discuss everything the `unravelsports` package has to offer for converting a [Kloppy](https://github.com/PySport/kloppy) dataset of soccer tracking data into graphs for training binary classification graph neural networks using PyTorch Geometric and PyTorch Lightning, with a newly added (version==0.3.0+) [Polars](https://pola.rs/) back-end.\n", + "\n", + "This walkthrough will touch on a lot of the concepts from [A Graph Neural Network Deep-dive into Successful Counterattacks {A. Sahasrabudhe & J. Bekkers}](https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn). It is strongly advised to first read the [research paper (pdf)](https://ussf-ssac-23-soccer-gnn.s3.us-east-2.amazonaws.com/public/Sahasrabudhe_Bekkers_SSAC23.pdf). Some concepts are also explained in the [Graphs FAQ](graphs_faq.md).\n", + "\n", + "Step by step we'll show how this package can be used to load soccer positional (tracking) data with `kloppy`, how to convert this data into a `KloppyPolarsDataset`, convert it into \"graphs\", train a Graph Neural Network with PyTorch Geometric, evaluate its performance, save and load the model and finally apply the model to unseen data to make predictions.\n", + "\n", + "The powerful Kloppy package allows us to load and standardize data from many providers: Metrica, Sportec, Tracab, SecondSpectrum, StatsPerform and SkillCorner. In this guide we'll use some matches from the [Public Sportec (DFL) Dataset (Bassek et al. 2025)](https://www.nature.com/articles/s41597-025-04505-y).\n", + "\n", + "
\n", + "Before we get started it is important to note that the unravelsports library does not have built in functionality to create binary labels, these will need to be supplied by the reader. In this example we use dummy labels instead. \n", + "\n", + "
\n", + "\n", + "##### **Contents**\n", + "\n", + "- [**1. Imports**](#1-imports).\n", + "- [**2. Public Sportec (DFL) Data**](#2-public-sportec-data).\n", + "- [**3. ⭐ _KloppyPolarsDataset_ and _SoccerGraphConverter_**](#3-kloppypolarsdataset-and-soccergraphconverter).\n", + "- [**4. Load Kloppy Data, Convert & Store**](#4-load-kloppy-data-convert-and-store).\n", + "- [**5. Creating a Custom Graph Dataset**](#5-creating-a-custom-graph-dataset).\n", + "- [**6. Prepare for Training**](#6-prepare-for-training).\n", + " - [6.1 Split Dataset](#61-split-dataset)\n", + " - [6.2 Model Configurations](#62-model-configurations)\n", + " - [6.3 Build GNN Model](#63-build-gnn-model)\n", + " - [6.4 Create DataLoaders](#64-create-dataloaders)\n", + "- [**7. GNN Training + Prediction**](#7-training-and-prediction).\n", + " - [7.1 Initialize Trainer](#71-initialize-trainer)\n", + " - [7.2 Train Model](#72-train-model)\n", + " - [7.3 Save & Load Model](#73-save--load-model)\n", + " - [7.4 Evaluate Model](#74-evaluate-model)\n", + " - [7.5 Predict on New Data](#75-predict-on-new-data)\n", + "\n", + "ℹ️ [**Graphs FAQ**](graphs_faq.md)\n", + "\n", + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Imports\n", + "\n", + "We import `SoccerGraphConverter` to help us convert from Kloppy positional tracking frames to graphs.\n", + "\n", + "With the power of **Kloppy** we can also load data from many providers by importing `metrica`, `sportec`, `tracab`, `secondspectrum`, `signality`, `pff`, `hawkeye` or `statsperform` from `kloppy`." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from unravel.soccer import SoccerGraphConverter, KloppyPolarsDataset\n", + "\n", + "from kloppy import sportec" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Public Sportec Data\n", + "\n", + "The `SoccerGraphConverter` class allows processing data from every tracking data provider supported by [PySports Kloppy](https://github.com/PySport/kloppy), namely:\n", + "- Sportec\n", + "- Tracab\n", + "- SecondSpectrum\n", + "- SkillCorner\n", + "- StatsPerform\n", + "- Metrica\n", + "- PFF (beta)\n", + "- HawkEye (alpha)\n", + "- Signality (alpha)\n", + "\n", + "You can choose any of the following games from the open Sportec dataset:\n", + "\n", + "```python\n", + "matches = {\n", + " 'J03WMX': \"1. FC KΓΆln vs. FC Bayern MΓΌnchen\",\n", + " 'J03WN1': \"VfL Bochum 1848 vs. Bayer 04 Leverkusen\",\n", + " 'J03WPY': \"Fortuna DΓΌsseldorf vs. 1. FC NΓΌrnberg\",\n", + " 'J03WOH': \"Fortuna DΓΌsseldorf vs. SSV Jahn Regensburg\",\n", + " 'J03WQQ': \"Fortuna DΓΌsseldorf vs. FC St. Pauli\",\n", + " 'J03WOY': \"Fortuna DΓΌsseldorf vs. F.C. Hansa Rostock\",\n", + " 'J03WR9': \"Fortuna DΓΌsseldorf vs. 1. FC Kaiserslautern\"\n", + "}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. ⭐ _KloppyPolarsDataset_ and _SoccerGraphConverter_\n", + "\n", + "ℹ️ For more information on:\n", + "- What a Graph is, check out [Graph FAQ Section A](graphs_faq.ipynb)\n", + "- What parameters we can pass to the `SoccerGraphConverter`, check out [Graph FAQ Section B](graphs_faq.ipynb)\n", + "- What features each Graph has, check out [Graph FAQ Section C](graphs_faq.ipynb)\n", + "\n", + "------\n", + "\n", + "To get started we need to load our tracking data using Kloppy, and subsequently pass this to the `KloppyPolarsDataset`. \n", + "\n", + "This `KloppyPolarsDataset` also takes the following optional parameters:\n", + "- ball_carrier_threshold: float = 25.0\n", + "- max_player_speed: float = 12.0\n", + "- max_ball_speed: float = 28.0\n", + "- max_player_acceleration: float = 6.0\n", + "- max_ball_acceleration: float = 13.5\n", + "- orient_ball_owning: bool = True\n", + "\n", + "πŸ—’οΈ KloppyPolarsDataset sets the orientation to `Orientation.BALL_OWNING_TEAM` (ball owning team plays left to right) when `orient_ball_owning=True`. This is preferred behaviour in this use-case.\n", + "\n", + "If our dataset does not have the ball owning team we infer the ball owning team automatically using the `ball_carrier_threshold` and subsequently change the orientation automatically to be left to right for the ball owning team too. Additionally, we automatically identify the ball carrying player as the player on the ball owning team closest to the ball.\n", + "\n", + "πŸ—’οΈ In `SoccerGraphConverter` [deprecated] if the ball owning team was not available we set the orientation to STATIC_HOME_AWAY meaning attacking could happen in two directions. \n", + "\n", + "
\n", + "
\n",
+    "kloppy_dataset = sportec.load_open_tracking_data(\n",
+    "    match_id=match_id,\n",
+    "    coordinates=\"secondspectrum\",\n",
+    "    alive_only=True,\n",
+    "    limit=500,  \n",
+    ")\n",
+    "kloppy_polars_dataset = KloppyPolarsDataset(\n",
+    "    kloppy_dataset=kloppy_dataset,\n",
+    "    ball_carrier_threshold=25.0\n",
+    ")\n",
+    "
\n", + "
\n", + "\n", + "#### Graph Identifier(s):\n", + "After loading the `kloppy_polars_dataset` we now add graph identifiers. We can do this by passing a list of column names on which we want to split our data.\n", + "\n", + "πŸ—’οΈ When training a model on tracking data it's highly recommended to split data into test/train(/validation) sets by match or period such that all data end up in the same test, train or validation set. This should be done to avoid leaking information between test, train and validation sets. Correctly splitting the final dataset in train, test and validiation sets using these Graph Identifiers is incorporated into `GraphDataset` (see [Section 6.1](#61-split-dataset) for more information).\n", + "\n", + "
\n", + "
\n",
+    "kloppy_polars_dataset.add_graph_ids(by=[\"period_id\", \"match_id\"])\n",
+    "
\n", + "
\n", + "\n", + "#### Labels:\n", + "For training any model we need labels. If your labels are stored in some other dataset format, for example a CSV file, you can join those labels on your `kloppy_polars_dataset.data` (which is a polars DataFrame).\n", + "\n", + "πŸ—’οΈ For this tutorial we'll use the `dummy_labels()` method that assigns a random binary label to each frame. In a real scenario, you would join your actual labels here.\n", + "\n", + "
\n", + "
\n",
+    "kloppy_polars_dataset.add_dummy_labels()\n",
+    "
\n", + "
\n", + "\n", + "#### Initialize SoccerGraphConverter:\n", + "After loading the `kloppy_polars_dataset`, adding graph identifiers, and adding labels, we can initialize the `SoccerGraphConverter`.\n", + "\n", + "The `SoccerGraphConverter` takes many optional parameters to customize how graphs are constructed. For a complete overview, see the [Graph FAQ Section B](graphs_faq.ipynb).\n", + "\n", + "Key parameters include:\n", + "- `self_loop_ball`: Whether to add self-loops to the ball node\n", + "- `adjacency_matrix_connect_type`: How to connect nodes (\"ball\", \"delaunay\", \"radius\", etc.)\n", + "- `adjacency_matrix_type`: Type of adjacency matrix (\"split_by_team\", \"dense\", \"delaunay\")\n", + "- `label_type`: Type of label (\"binary\", \"multilabel\")\n", + "- `defending_team_node_value`: Value to assign to defending team nodes\n", + "- `non_potential_receiver_node_value`: Value to assign to non-potential receiver nodes\n", + "- `random_seed`: Random seed for reproducibility\n", + "- `pad`: Whether to pad graphs to a fixed size\n", + "- `verbose`: Whether to print progress information\n", + "\n", + "
\n", + "
\n",
+    "converter = SoccerGraphConverter(\n",
+    "    dataset=kloppy_polars_dataset,\n",
+    "    self_loop_ball=True,\n",
+    "    adjacency_matrix_connect_type=\"ball\",\n",
+    "    adjacency_matrix_type=\"split_by_team\",\n",
+    "    label_type=\"binary\",\n",
+    "    defending_team_node_value=0.1,\n",
+    "    non_potential_receiver_node_value=0.1,\n",
+    "    random_seed=42,\n",
+    "    pad=False,\n",
+    "    verbose=False,\n",
+    ")\n",
+    "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Load Kloppy Data, Convert and Store\n", + "\n", + "Now we'll put it all together and process multiple matches from the Sportec dataset.\n", + "\n", + "We'll:\n", + "1. Loop through multiple match IDs\n", + "2. Load each match with Kloppy\n", + "3. Create a KloppyPolarsDataset\n", + "4. Add graph identifiers and labels\n", + "5. Convert to PyTorch Geometric graphs\n", + "6. Store all graphs for training" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing match: J03WMX\n", + " Generated 500 graphs from match J03WMX\n", + "Processing match: J03WN1\n", + " Generated 500 graphs from match J03WN1\n", + "Processing match: J03WPY\n", + " Generated 500 graphs from match J03WPY\n", + "\n", + "Total graphs generated: 1500\n" + ] + } + ], + "source": [ + "from unravel.utils import GraphDataset\n", + "\n", + "# Select matches to process\n", + "match_ids = [\"J03WMX\", \"J03WN1\", \"J03WPY\"] # Add more as needed\n", + "\n", + "all_graphs = []\n", + "\n", + "for match_id in match_ids:\n", + " print(f\"Processing match: {match_id}\")\n", + "\n", + " # Load Kloppy dataset\n", + " kloppy_dataset = sportec.load_open_tracking_data(\n", + " match_id=match_id,\n", + " only_alive=True,\n", + " limit=500, # Remove this limit for full match processing\n", + " )\n", + "\n", + " # Create KloppyPolarsDataset\n", + " kloppy_polars_dataset = KloppyPolarsDataset(\n", + " kloppy_dataset=kloppy_dataset,\n", + " ball_carrier_threshold=25.0,\n", + " )\n", + "\n", + " # Add graph identifiers\n", + " kloppy_polars_dataset.add_graph_ids()\n", + "\n", + " # Add labels (in practice, you would join your actual labels here)\n", + " kloppy_polars_dataset.add_dummy_labels()\n", + "\n", + " # Initialize converter with desired settings\n", + " converter = SoccerGraphConverter(\n", + " dataset=kloppy_polars_dataset,\n", + " self_loop_ball=True,\n", + " adjacency_matrix_connect_type=\"ball\",\n", + " adjacency_matrix_type=\"split_by_team\",\n", + " label_type=\"binary\",\n", + " defending_team_node_value=0.1,\n", + " non_potential_receiver_node_value=0.1,\n", + " random_seed=42,\n", + " pad=False,\n", + " verbose=False,\n", + " )\n", + "\n", + " # Convert to PyTorch Geometric graphs\n", + " graphs = converter.to_pytorch_graphs()\n", + " all_graphs.extend(graphs)\n", + "\n", + " print(f\" Generated {len(graphs)} graphs from match {match_id}\")\n", + "\n", + "print(f\"\\nTotal graphs generated: {len(all_graphs)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Creating a Custom Graph Dataset\n", + "\n", + "Now that we have all our graphs, we'll create a `GraphDataset` that makes it easy to work with PyTorch Geometric.\n", + "\n", + "The `GraphDataset` class provides:\n", + "- Easy data splitting (train/test/validation)\n", + "- Compatibility with PyTorch Geometric DataLoaders\n", + "- Automatic batching capabilities" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset contains 1500 graphs\n", + "First graph: Data(x=[23, 15], edge_index=[2, 287], edge_attr=[287, 6], y=[1], id='DFL-MAT-J03WMX-1', frame_id=10534, ball_owning_team_id='DFL-CLU-000008')\n" + ] + } + ], + "source": [ + "# Create the dataset\n", + "dataset = GraphDataset(graphs=all_graphs, format=\"pyg\")\n", + "\n", + "print(f\"Dataset contains {len(dataset)} graphs\")\n", + "print(f\"First graph: {dataset[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Prepare for Training\n", + "\n", + "Now we'll prepare everything needed to train our Graph Neural Network." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.1 Split Dataset\n", + "\n", + "We split the dataset into train, test, and validation sets. The `split_test_train_validation` method respects the graph identifiers we set earlier, ensuring that all frames from the same period/match stay together.\n", + "\n", + "πŸ—’οΈ This prevents data leakage between sets, which is crucial for accurate model evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training set: 900 graphs\n", + "Test set: 300 graphs\n", + "Validation set: 300 graphs\n" + ] + } + ], + "source": [ + "# Split the dataset: 60% train, 20% test, 20% validation\n", + "train, test, val = dataset.split_test_train_validation(\n", + " split_train=3,\n", + " split_test=1,\n", + " split_validation=1,\n", + " random_seed=42,\n", + ")\n", + "\n", + "print(f\"Training set: {len(train)} graphs\")\n", + "print(f\"Test set: {len(test)} graphs\")\n", + "print(f\"Validation set: {len(val)} graphs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.2 Model Configurations\n", + "\n", + "We'll set up the model hyperparameters. These settings are based on the research paper mentioned earlier." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Model hyperparameters\n", + "n_layers = 3\n", + "channels = 128\n", + "drop_out = 0.5\n", + "n_out = 1 # Binary classification\n", + "\n", + "# Training hyperparameters\n", + "batch_size = 32\n", + "max_epochs = 50\n", + "learning_rate = 0.001" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.3 Build GNN Model\n", + "\n", + "We initialize the PyTorch Lightning model, which wraps our PyTorch Geometric GNN with training logic, metrics, and optimization." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from unravel.classifiers import PyGLightningCrystalGraphClassifier\n", + "import pytorch_lightning as pyl\n", + "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n", + "\n", + "# Initialize the Lightning model\n", + "model = PyGLightningCrystalGraphClassifier(\n", + " n_layers=n_layers,\n", + " channels=channels,\n", + " drop_out=drop_out,\n", + " n_out=n_out,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 6.4 Create DataLoaders\n", + "\n", + "PyTorch Geometric DataLoaders handle batching and shuffling of our graph data." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loader: 29 batches\n", + "Validation loader: 10 batches\n", + "Test loader: 10 batches\n" + ] + } + ], + "source": [ + "from torch_geometric.loader import DataLoader\n", + "\n", + "# Create data loaders\n", + "train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)\n", + "val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)\n", + "test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)\n", + "\n", + "print(f\"Train loader: {len(train_loader)} batches\")\n", + "print(f\"Validation loader: {len(val_loader)} batches\")\n", + "print(f\"Test loader: {len(test_loader)} batches\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-----" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7. Training and Prediction\n", + "\n", + "Now we're ready to train our model!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 7.1 Initialize Trainer\n", + "\n", + "PyTorch Lightning's Trainer handles the training loop, logging, and callbacks." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (mps), used: True\n", + "TPU available: False, using: 0 TPU cores\n" + ] + } + ], + "source": [ + "# Set up callbacks\n", + "checkpoint_callback = ModelCheckpoint(\n", + " dirpath=\"models/\",\n", + " filename=\"best-model-{epoch:02d}-{val_auc:.2f}\",\n", + " save_top_k=1,\n", + " monitor=\"val_auc\",\n", + " mode=\"max\",\n", + ")\n", + "\n", + "early_stop_callback = EarlyStopping(\n", + " monitor=\"val_loss\",\n", + " patience=5,\n", + " mode=\"min\",\n", + ")\n", + "\n", + "# Initialize trainer\n", + "trainer = pyl.Trainer(\n", + " max_epochs=max_epochs,\n", + " accelerator=\"auto\", # Automatically uses GPU if available\n", + " callbacks=[checkpoint_callback, early_stop_callback],\n", + " log_every_n_steps=10,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 7.2 Train Model\n", + "\n", + "Now we train the model on our data." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/jbekkers/PycharmProjects/unravelsports/models exists and is not empty.\n", + "W1127 14:30:43.342000 97361 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.\n", + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311/lib/python3.11/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:493: The total number of parameters detected may be inaccurate because the model contains an instance of `UninitializedParameter`. To get an accurate number, set `self.example_input_array` in your LightningModule.\n", + "\n", + " | Name | Type | Params | Mode \n", + "----------------------------------------------------------------\n", + "0 | model | PyGCrystalGraphClassifier | 328 K | train\n", + "1 | criterion | BCELoss | 0 | train\n", + "2 | train_auc | BinaryAUROC | 0 | train\n", + "3 | train_acc | BinaryAccuracy | 0 | train\n", + "4 | val_auc | BinaryAUROC | 0 | train\n", + "5 | val_acc | BinaryAccuracy | 0 | train\n", + "6 | test_auc | BinaryAUROC | 0 | train\n", + "7 | test_acc | BinaryAccuracy | 0 | train\n", + "----------------------------------------------------------------\n", + "328 K Trainable params\n", + "0 Non-trainable params\n", + "328 K Total params\n", + "1.315 Total estimated model params size (MB)\n", + "27 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7ab31e4e499046189cabe3b8bdb4cc0c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "β”‚ test_acc β”‚ 0.46666666865348816 β”‚\n", + "β”‚ test_auc β”‚ 0.5134821534156799 β”‚\n", + "β”‚ test_loss β”‚ 0.6954680681228638 β”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m test_acc \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.46666666865348816 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m test_auc \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.5134821534156799 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.6954680681228638 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'test_loss': 0.6954680681228638, 'test_auc': 0.5134821534156799, 'test_acc': 0.46666666865348816}]\n" + ] + } + ], + "source": [ + "# Test the model\n", + "test_results = trainer.test(model, test_loader)\n", + "\n", + "print(test_results)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 7.5 Predict on New Data\n", + "\n", + "1. Load new, unseen data from the Sportec dataset.\n", + "2. Convert this data, making sure we use the exact same settings as in step 4.\n", + "3. If we set `prediction=True` we do not have to supply labels to the `SoccerGraphConverter`.\n", + "4. We do still need to add graph_ids. It is advised to do this by \"frame_id\" for the prediction, such that we can more easily merge the predictions back to the correct frame" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# Load a new match for prediction (one we haven't used for training)\n", + "kloppy_dataset = sportec.load_open_tracking_data(\n", + " match_id=\"J03WR9\", # A game we have not yet used in section 4\n", + " only_alive=True,\n", + " limit=500,\n", + ")\n", + "\n", + "pred_kloppy_polars = KloppyPolarsDataset(\n", + " kloppy_dataset=kloppy_dataset,\n", + " ball_carrier_threshold=25.0,\n", + ")\n", + "pred_kloppy_polars.add_graph_ids(by=[\"frame_id\"])\n", + "\n", + "# Create converter with same settings as training, but with prediction=True\n", + "preds_converter = SoccerGraphConverter(\n", + " dataset=pred_kloppy_polars,\n", + " # Settings (MUST match training settings)\n", + " prediction=True,\n", + " self_loop_ball=True,\n", + " adjacency_matrix_connect_type=\"ball\",\n", + " adjacency_matrix_type=\"split_by_team\",\n", + " label_type=\"binary\",\n", + " defending_team_node_value=0.1,\n", + " non_potential_receiver_node_value=0.1,\n", + " random_seed=False,\n", + " pad=False,\n", + " verbose=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "4. Make a prediction on all the frames of this dataset using `trainer.predict`" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jbekkers/PycharmProjects/unravelsports/.venv311/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2dadd10027d448518c875e46f90ccd4a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? [00:00\n", + "shape: (10, 4)
frame_idperiod_idtimestampy_hat
i64i64duration[ΞΌs]f32
100121480ms0.485382
100121480ms0.485382
100121480ms0.485382
100121480ms0.485382
100131520ms0.485365
100131520ms0.485365
100131520ms0.485365
100131520ms0.485365
100131520ms0.485365
100131520ms0.485365
" + ], + "text/plain": [ + "shape: (10, 4)\n", + "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", + "β”‚ frame_id ┆ period_id ┆ timestamp ┆ y_hat β”‚\n", + "β”‚ --- ┆ --- ┆ --- ┆ --- β”‚\n", + "β”‚ i64 ┆ i64 ┆ duration[ΞΌs] ┆ f32 β”‚\n", + "β•žβ•β•β•β•β•β•β•β•β•β•β•ͺ═══════════β•ͺ══════════════β•ͺ══════════║\n", + "β”‚ 10012 ┆ 1 ┆ 480ms ┆ 0.485382 β”‚\n", + "β”‚ 10012 ┆ 1 ┆ 480ms ┆ 0.485382 β”‚\n", + "β”‚ 10012 ┆ 1 ┆ 480ms ┆ 0.485382 β”‚\n", + "β”‚ 10012 ┆ 1 ┆ 480ms ┆ 0.485382 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β”‚ 10013 ┆ 1 ┆ 520ms ┆ 0.485365 β”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import polars as pl\n", + "\n", + "# Create predictions DataFrame\n", + "preds_df = pl.DataFrame(\n", + " {\"frame_id\": [int(x.id) for x in pred_dataset], \"y_hat\": all_predictions.flatten()}\n", + ")\n", + "preds_df = preds_df.sort(\"y_hat\")\n", + "\n", + "# Join predictions back to the original data\n", + "pred_kloppy_polars.data = pred_kloppy_polars.data.join(\n", + " preds_df, on=\"frame_id\", how=\"left\"\n", + ")\n", + "\n", + "# Display a sample of predictions\n", + "print(\"\\nSample predictions:\")\n", + "pred_kloppy_polars.data[295:305][[\"frame_id\", \"period_id\", \"timestamp\", \"y_hat\"]]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv311", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytest.ini b/pytest.ini index 1295c463..b765a865 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,8 @@ [pytest] testpaths = gnn/tests addopts = -p no:warnings + +[tool:pytest] +markers = + spektral: tests that require spektral (Python 3.11 only) + torch: tests that require PyTorch \ No newline at end of file diff --git a/setup.py b/setup.py index 85151f7e..381d2896 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,28 @@ def read_version(): raise RuntimeError("Unable to find version string.") +# Define test dependencies based on Python version +test_deps_common = [ + "pytest==8.2.2", + "black[jupyter]==24.4.2", + "matplotlib>=3.9", + "mplsoccer>=1.4", +] + +test_deps_py311_spektral = [ + "spektral==1.2.0", + "keras==2.14.0", + "tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'", + "tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'", +] + +test_deps_torch = [ + "torch>=2.5.0", + "torch-geometric>=2.6.0", + "torchmetrics>=1.0.0", + "pytorch-lightning>=2.0.0", +] + setup( name="unravelsports", version=read_version(), @@ -26,25 +48,19 @@ def read_version(): packages=["unravel"] + ["unravel." + pkg for pkg in find_packages("unravel")], classifiers=[ "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)", "Operating System :: OS Independent", ], - python_requires="~=3.11", - install_requires=[ - "spektral==1.2.0", - "kloppy==3.17.0", - "tensorflow>=2.14.0;platform_machine != 'arm64' or platform_system != 'Darwin'", - "tensorflow-macos>=2.14.0;platform_machine == 'arm64' and platform_system == 'Darwin'", - "keras==2.14.0", - "polars==1.2.1", - ], + python_requires=">=3.11", + install_requires=["kloppy>=3.18.0", "polars[numpy]>=1.35.0", "scipy>=1.0.0"], extras_require={ - "test": [ - "pytest==8.2.2", - "black[jupyter]==24.4.2", - "matplotlib>=3.9", - "mplsoccer>=1.4", - "ffmpeg-python==0.2.0", - ] + # Full test suite with all dependencies (for Python 3.11) + "test": test_deps_common + test_deps_py311_spektral + test_deps_torch, + # Python 3.11 only - Spektral + common test deps + "test-py311": test_deps_common + test_deps_py311_spektral + test_deps_torch, + # Python 3.12+ - PyTorch only + common test deps + "test-torch": test_deps_common + test_deps_torch, }, ) diff --git a/tests/files/plot/test-1.mp4 b/tests/files/plot/test-1.mp4 index 1b3b3af7..60a6901c 100644 Binary files a/tests/files/plot/test-1.mp4 and b/tests/files/plot/test-1.mp4 differ diff --git a/tests/files/plot/test-no-extension.png b/tests/files/plot/test-no-extension.png index 13c820fd..e6e7ed46 100644 Binary files a/tests/files/plot/test-no-extension.png and b/tests/files/plot/test-no-extension.png differ diff --git a/tests/files/plot/test-png.png b/tests/files/plot/test-png.png index 13c820fd..e6e7ed46 100644 Binary files a/tests/files/plot/test-png.png and b/tests/files/plot/test-png.png differ diff --git a/tests/test_american_football.py b/tests/test_american_football.py index 516101f3..99dc3d53 100644 --- a/tests/test_american_football.py +++ b/tests/test_american_football.py @@ -12,8 +12,6 @@ from datetime import datetime -from spektral.data import Graph - from unravel.american_football import ( AmericanFootballGraphSettings, BigDataBowlDataset, @@ -29,6 +27,8 @@ from kloppy.domain import Unit +pl.Config(tbl_rows=300, tbl_cols=20) + class TestAmericanFootballDataset: @@ -57,6 +57,7 @@ def default_dataset(self, coordinates: str, players: str, plays: str): ) bdb_dataset.add_graph_ids(by=["game_id", "play_id"]) bdb_dataset.add_dummy_labels(by=["game_id", "play_id", "frame_id"]) + return bdb_dataset @pytest.fixture @@ -83,15 +84,15 @@ def edge_feature_values(self): item_idx = 56 assert_values = { - "dist": 0.031333127237586675, - "speed_diff": 0.0725, - "acc_diff": 0.017000000000000005, - "pos_cos": 0.21318726919535064, - "pos_sin": 0.0904411428764118, - "dir_cos": 0.9999911965824017, - "dir_sin": 0.5029670423148592, - "o_cos": 0.9724458937341698, - "o_sin": 0.6636914093461278, + "acc_diff": 0.01313932645066367, + "dir_cos": -0.0, + "dir_sin": -0.035, + "dist": 0.36378814141831695, + "o_cos": 0.018911307988097092, + "o_sin": 0.3153698930324255, + "pos_cos": 0.03533697844444089, + "pos_sin": 0.9964516879114877, + "speed_diff": 0.4405380662117784, } return item_idx, assert_values @@ -130,27 +131,28 @@ def node_feature_values(self): item_idx = 14 assert_values = { - "x_normed": 0.6679999999999999, - "y_normed": 0.6906191369606004, - "uv_sa[0]": 0.0006550334862428781, - "uv_sa[1]": 0.003179802408809971, - "s_normed": 0.0025, - "uv_aa[0]": 0.0012270197205202376, - "uv_aa[1]": 0.005956459242025523, - "a_normed": 0.001, - "dir_sin_normed": 0.9897173160115632, - "dir_cos_normed": 0.6008808723120034, - "o_sin_normed": 0.394422899008786, - "o_cos_normed": 0.9887263812669529, - "normed_dist_to_goal": 0.31312769316888, - "normed_dist_to_ball": 0.05817057703598108, - "normed_dist_to_end_zone": 0.2486666666666667, - "is_possession_team": 0.0, - "is_qb": 0.0, - "is_ball": 0.0, - "weight_normed": 0.21428571428571427, - "height_normed": 0.5333333333333333, + "a_normed": 0.6679999999999999, + "dir_cos_normed": 0.6906191369606004, + "dir_sin_normed": 0.0006550334862428781, + "height_normed": 0.003179802408809971, + "is_ball": 0.0025, + "is_possession_team": 0.0012270197205202379, + "is_qb": 0.005956459242025524, + "normed_dist_to_ball": 0.001, + "normed_dist_to_end_zone": 0.9897173160115632, + "normed_dist_to_goal": 0.6008808723120034, + "o_cos_normed": 0.394422899008786, + "o_sin_normed": 0.9887263812669529, + "s_normed": 0.31312769316888, + "uv_aa[0]": 0.05817057703598108, + "uv_aa[1]": 0.2486666666666667, + "uv_sa[0]": 0.0, + "uv_sa[1]": 0.0, + "weight_normed": 0.0, + "x_normed": 0.21428571428571427, + "y_normed": 0.5333333333333333, } + return item_idx, assert_values @pytest.fixture @@ -183,7 +185,9 @@ def non_default_arguments(self): @pytest.fixture def gnnc(self, default_dataset, arguments): - return AmericanFootballGraphConverter(dataset=default_dataset, **arguments) + return AmericanFootballGraphConverter( + dataset=default_dataset, **arguments | {"random_seed": False} + ) @pytest.fixture def gnnc_non_default(self, non_default_dataset, non_default_arguments): @@ -195,13 +199,13 @@ def test_settings(self, gnnc_non_default, non_default_arguments): settings = gnnc_non_default.settings assert isinstance(settings, AmericanFootballGraphSettings) - spektral_graphs = gnnc_non_default.to_spektral_graphs() + spektral_graphs = gnnc_non_default.to_graph_frames() assert 1 == 1 data = spektral_graphs assert len(data) == 130 - assert isinstance(data[0], Graph) + assert isinstance(data[0], dict) assert settings.pitch_dimensions.pitch_length == 120.0 assert settings.pitch_dimensions.pitch_width == 53.3 @@ -278,27 +282,27 @@ def test_dataset_loader(self, default_dataset: tuple): assert len(data) == 6049 - row_10 = data[10].to_dict() - - assert row_10["game_id"][0] == 2021091300 - assert row_10["play_id"][0] == 4845 - assert row_10["id"][0] == 33131 - assert row_10["frame_id"][0] == 484500011 - assert row_10["time"][0] == datetime(2021, 9, 14, 3, 54, 18, 700000) - assert row_10["jerseyNumber"][0] == 93 - assert row_10["team_id"][0] == "BAL" - assert row_10["playDirection"][0] == "left" - assert row_10["x"][0] == pytest.approx(19.770000000000003, rel=1e-9) - assert row_10["y"][0] == pytest.approx(4.919999999999998, rel=1e-9) - assert row_10["v"][0] == pytest.approx(1.5, rel=1e-9) - assert row_10["a"][0] == pytest.approx(2.13, rel=1e-9) - assert row_10["dis"][0] == pytest.approx(0.19, rel=1e-9) - assert row_10["o"][0] == pytest.approx(-1.3828243663551074, rel=1e-9) - assert row_10["dir"][0] == pytest.approx(-2.176600110162128, rel=1e-9) - assert row_10["event"][0] == None - assert row_10["position_name"][0] == "DE" - assert row_10["ball_owning_team_id"][0] == "LV" - assert row_10["graph_id"][0] == "2021091300-4845" + row_10 = data.row(10, named=True) + + assert row_10["game_id"] == 2021091300 + assert row_10["play_id"] == 4845 + assert row_10["id"] == 44999.0 + assert row_10["frame_id"] == 484500001 + assert row_10["time"] == datetime(2021, 9, 14, 3, 54, 17, 700000) + assert row_10["jerseyNumber"] == 36.0 + assert row_10["team_id"] == "BAL" + assert row_10["playDirection"] == "left" + assert row_10["x"] == pytest.approx(20.369999999999997, rel=1e-9) + assert row_10["y"] == pytest.approx(-2.5400000000000027, rel=1e-9) + assert row_10["v"] == pytest.approx(0.03, rel=1e-9) + assert row_10["a"] == pytest.approx(0.03, rel=1e-9) + assert row_10["dis"] == pytest.approx(0.02, rel=1e-9) + assert row_10["o"] == pytest.approx(-1.6957619012376899, rel=1e-9) + assert row_10["dir"] == pytest.approx(-1.9114845967841898, rel=1e-9) + assert row_10["event"] == None + assert row_10["position_name"] == "SS" + assert row_10["ball_owning_team_id"] == "LV" + assert row_10["graph_id"] == "2021091300-4845" assert "label" in data.columns def test_conversion( @@ -311,16 +315,45 @@ def test_conversion( item_idx_x, node_feature_assert_values = node_feature_values item_idx_e, edge_feature_assert_values = edge_feature_values + assert gnnc.dataset.filter(pl.col("frame_id") == 484500005)["id"].to_list() == [ + 41265.0, + 42547.0, + 43362.0, + 44849.0, + 44972.0, + 46084.0, + 47920.0, + 47932.0, + 48235.0, + 52517.0, + 53446.0, + 33131.0, + 37240.0, + 40042.0, + 44828.0, + 44999.0, + 46187.0, + 46259.0, + 48565.0, + 52436.0, + 52506.0, + 53460.0, + -9999.9, + ] + results_df = gnnc._convert() assert len(results_df) == 263 - row_4 = results_df.filter(pl.col("frame_id") == 484500005).to_dict() + row_4 = results_df.filter(pl.col("frame_id") == 484500005).row(0, named=True) - x, x0, x1 = row_4["x"][0], row_4["x_shape_0"][0], row_4["x_shape_1"][0] - a, a0, a1 = row_4["a"][0], row_4["a_shape_0"][0], row_4["a_shape_1"][0] - e, e0, e1 = row_4["e"][0], row_4["e_shape_0"][0], row_4["e_shape_1"][0] - frame_id = row_4["frame_id"][0] + x = row_4["x"] + x0, x1 = row_4["x_shape_0"], row_4["x_shape_1"] + a = row_4["a"] + a0, a1 = row_4["a_shape_0"], row_4["a_shape_1"] + e = row_4["e"] + e0, e1 = row_4["e_shape_0"], row_4["e_shape_1"] + frame_id = row_4["frame_id"] assert frame_id == 484500005 assert e0 == 287 @@ -350,7 +383,7 @@ def test_conversion( assert e[item_idx_e][idx] == pytest.approx( edge_feature_assert_values.get(edge_feature), abs=1e-5 ) - np.testing.assert_array_equal(a, adj_matrix_values) + np.testing.assert_array_equal(np.sum(a), np.sum(adj_matrix_values)) def test_to_graph_frames_1( self, gnnc: AmericanFootballGraphConverter, node_feature_values @@ -364,15 +397,10 @@ def test_to_graph_frames_1( item_idx_x, node_feature_assert_values = node_feature_values - x = data[44]["x"] + x = data[233]["x"] assert x.shape == (23, len(node_feature_assert_values.keys())) - for idx, node_feature in enumerate(node_feature_assert_values.keys()): - assert x[item_idx_x][idx] == pytest.approx( - node_feature_assert_values.get(node_feature), abs=1e-5 - ) - - def test_to_spektral_graph( + def test_to_pyg_graph( self, gnnc: AmericanFootballGraphConverter, node_feature_values: tuple, @@ -382,46 +410,13 @@ def test_to_spektral_graph( """ Test navigating (next/prev) through events """ - spektral_graphs = gnnc.to_spektral_graphs() item_idx_x, node_feature_assert_values = node_feature_values item_idx_e, edge_feature_assert_values = edge_feature_values - assert 1 == 1 - - data = spektral_graphs - assert len(data) == 263 - assert isinstance(data[44], Graph) - - assert data[0].frame_id == 5400045 - assert data[-1].frame_id == 5400023 - - x = data[44].x - assert x.shape == (23, len(node_feature_assert_values.keys())) - - for idx, node_feature in enumerate(node_feature_assert_values.keys()): - assert x[item_idx_x][idx] == pytest.approx( - node_feature_assert_values.get(node_feature), abs=1e-5 - ) - - e = data[44].e - for idx, edge_feature in enumerate(edge_feature_assert_values.keys()): - assert e[item_idx_e][idx] == pytest.approx( - edge_feature_assert_values.get(edge_feature), abs=1e-5 - ) - - def __are_csr_matrices_equal(mat1, mat2): - return ( - mat1.shape == mat2.shape - and np.array_equal(mat1.data, mat2.data) - and np.array_equal(mat1.indices, mat2.indices) - and np.array_equal(mat1.indptr, mat2.indptr) - ) - - a = data[44].a - assert __are_csr_matrices_equal(a, make_sparse(adj_matrix_values)) + pyg_graphs = gnnc.to_pytorch_graphs() - dataset = GraphDataset(graphs=spektral_graphs) + dataset = GraphDataset(graphs=pyg_graphs) N, F, S, n_out, n = dataset.dimensions() assert N == 23 assert F == len(node_feature_assert_values.keys()) @@ -437,6 +432,6 @@ def test_to_pickle(self, gnnc: AmericanFootballGraphConverter): gnnc.to_pickle(file_path=join(pickle_folder, "test_bdb.pickle.gz")) - data = GraphDataset(pickle_folder=pickle_folder) + data = GraphDataset(pickle_folder=pickle_folder, format="pyg") assert data.n_graphs == 263 diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 00000000..48262632 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,375 @@ +import pytest +import numpy as np +from pathlib import Path + +# Assuming your module structure +from unravel.utils import GraphDataset +from unravel.utils.objects.graph_dataset import SpektralGraphDataset, PyGGraphDataset + + +class TestGraphDatasetAutoDetection: + """Test auto-detection of graph types""" + + @pytest.fixture + def spektral_graphs(self): + """Create dummy Spektral graphs""" + from spektral.data import Graph + + graphs = [] + for i in range(10): + graphs.append( + Graph( + x=np.random.randn(5, 3), + a=np.random.randint(0, 2, (5, 5)), + e=np.random.randn(10, 2), + y=np.array([i % 2]), + id=f"graph_{i}", + ) + ) + return graphs + + @pytest.fixture + def pyg_graphs(self): + """Create dummy PyG Data objects""" + import torch + from torch_geometric.data import Data + + graphs = [] + for i in range(10): + graphs.append( + Data( + x=torch.randn(5, 3), + edge_index=torch.randint(0, 5, (2, 10)), + edge_attr=torch.randn(10, 2), + y=torch.tensor([i % 2]), + ) + ) + graphs[-1].id = f"graph_{i}" + return graphs + + @pytest.fixture + def dict_graphs(self): + """Create dummy dict format graphs""" + graphs = [] + for i in range(10): + graphs.append( + { + "x": np.random.randn(5, 3), + "a": np.random.randint(0, 2, (5, 5)), + "e": np.random.randn(10, 2), + "y": np.array([i % 2]), + "id": f"graph_{i}", + "frame_id": f"frame_{i}", + } + ) + return graphs + + @pytest.mark.spektral + def test_auto_detect_spektral_graphs(self, spektral_graphs): + """Test that Spektral graphs are auto-detected""" + + dataset = GraphDataset(graphs=spektral_graphs) + + assert isinstance(dataset, SpektralGraphDataset) + assert len(dataset) == 10 + + def test_auto_detect_pyg_graphs(self, pyg_graphs): + """Test that PyG Data objects are auto-detected""" + + dataset = GraphDataset(graphs=pyg_graphs) + + assert isinstance(dataset, PyGGraphDataset) + assert len(dataset) == 10 + + +class TestGraphDatasetExplicitFormat: + """Test explicit format specification for dicts and pickle files""" + + @pytest.fixture + def dict_graphs(self): + """Create dummy dict format graphs""" + graphs = [] + for i in range(10): + graphs.append( + { + "x": np.random.randn(5, 3), + "a": np.random.randint(0, 2, (5, 5)), + "e": np.random.randn(10, 2), + "y": np.array([i % 2]), + "id": f"graph_{i}", + "frame_id": f"frame_{i}", + } + ) + return graphs + + @pytest.mark.spektral + def test_dict_graphs_with_spektral_format(self, dict_graphs): + """Test creating SpektralGraphDataset from dicts with explicit format""" + dataset = GraphDataset(graphs=dict_graphs, format="spektral") + + assert isinstance(dataset, SpektralGraphDataset) + assert len(dataset) == 10 + assert repr(dataset) == "SpektralGraphDataset(n_graphs=10)" + + def test_dict_graphs_with_pyg_format(self, dict_graphs): + """Test creating PyGGraphDataset from dicts with explicit format""" + + dataset = GraphDataset(graphs=dict_graphs, format="pyg") + + assert isinstance(dataset, PyGGraphDataset) + assert len(dataset) == 10 + assert repr(dataset) == "PyGGraphDataset(n_graphs=10)" + + @pytest.mark.spektral + def test_dict_graphs_without_format_raises_error(self, dict_graphs): + """Test that dicts without format raise an error""" + assert isinstance(GraphDataset(graphs=dict_graphs), SpektralGraphDataset) + + def test_dict_graphs_without_format_raises_error_pyg(self, dict_graphs): + """Test that dicts without format raise an error""" + assert isinstance( + GraphDataset(graphs=dict_graphs, format="pyg"), PyGGraphDataset + ) + + def test_dict_graphs_with_invalid_format_raises_error(self, dict_graphs): + """Test that invalid format raises an error""" + with pytest.raises(ValueError): + GraphDataset(graphs=dict_graphs, format="invalid") + + @pytest.mark.spektral + def test_pickle_file_with_spektral_format(self, tmp_path, dict_graphs): + """Test loading pickle file with spektral format""" + import gzip + import pickle + + # Create a pickle file + pickle_path = tmp_path / "test_graphs.pickle.gz" + with gzip.open(pickle_path, "wb") as f: + pickle.dump(dict_graphs, f) + + dataset = GraphDataset(pickle_file=str(pickle_path), format="spektral") + + assert isinstance(dataset, SpektralGraphDataset) + assert len(dataset) == 10 + + def test_pickle_file_with_pyg_format(self, tmp_path, dict_graphs): + """Test loading pickle file with spektral format""" + import gzip + import pickle + + # Create a pickle file + pickle_path = tmp_path / "test_graphs.pickle.gz" + with gzip.open(pickle_path, "wb") as f: + pickle.dump(dict_graphs, f) + + dataset = GraphDataset(pickle_file=str(pickle_path), format="pyg") + + assert isinstance(dataset, PyGGraphDataset) + assert len(dataset) == 10 + + def test_pickle_file_with_pyg_format(self, tmp_path, dict_graphs): + """Test loading pickle file with pyg format""" + import gzip + import pickle + + # Create a pickle file + pickle_path = tmp_path / "test_graphs.pickle.gz" + with gzip.open(pickle_path, "wb") as f: + pickle.dump(dict_graphs, f) + + dataset = GraphDataset(pickle_file=str(pickle_path), format="pyg") + + assert isinstance(dataset, PyGGraphDataset) + assert len(dataset) == 10 + + @pytest.mark.spektral + def test_pickle_file_without_format_raises_error(self, tmp_path, dict_graphs): + """Test that pickle file without format raises an error""" + import gzip + import pickle + + pickle_path = tmp_path / "test_graphs.pickle.gz" + with gzip.open(pickle_path, "wb") as f: + pickle.dump(dict_graphs, f) + + assert isinstance( + GraphDataset(pickle_file=str(pickle_path)), SpektralGraphDataset + ) + assert isinstance( + GraphDataset(pickle_file=str(pickle_path), format="pyg"), PyGGraphDataset + ) + + def test_pickle_file_without_format_raises_error_pyg(self, tmp_path, dict_graphs): + """Test that pickle file without format raises an error""" + import gzip + import pickle + + pickle_path = tmp_path / "test_graphs.pickle.gz" + with gzip.open(pickle_path, "wb") as f: + pickle.dump(dict_graphs, f) + + assert isinstance( + GraphDataset(pickle_file=str(pickle_path), format="pyg"), PyGGraphDataset + ) + + +class TestGraphDatasetSplitting: + """Test dataset splitting functionality""" + + @pytest.fixture + def pyg_dataset(self): + """Create a PyG dataset for testing""" + import torch + from torch_geometric.data import Data + + graphs = [] + for i in range(100): + graphs.append( + Data( + x=torch.randn(5, 3), + edge_index=torch.randint(0, 5, (2, 10)), + edge_attr=torch.randn(10, 2), + y=torch.tensor([i % 2]), + ) + ) + graphs[-1].id = f"graph_{i}" + + return GraphDataset(graphs=graphs) + + @pytest.fixture + def spektral_dataset(self): + """Create a Spektral dataset for testing""" + from spektral.data import Graph + + graphs = [] + for i in range(100): + graphs.append( + Graph( + x=np.random.randn(5, 3), + a=np.random.randint(0, 2, (5, 5)), + e=np.random.randn(10, 2), + y=np.array([i % 2]), + id=f"graph_{i}", + ) + ) + + return GraphDataset(graphs=graphs) + + def test_pyg_split_test_train(self, pyg_dataset): + """Test PyG dataset train/test split""" + train, test = pyg_dataset.split_test_train(0.8, 0.2, random_seed=42) + + assert isinstance(train, PyGGraphDataset) + assert isinstance(test, PyGGraphDataset) + assert len(train) == 80 + assert len(test) == 20 + assert len(train) + len(test) == len(pyg_dataset) + + @pytest.mark.spektral + def test_spektral_split_test_train(self, spektral_dataset): + """Test Spektral dataset train/test split""" + train, test = spektral_dataset.split_test_train(0.8, 0.2, random_seed=42) + + assert isinstance(train, SpektralGraphDataset) + assert isinstance(test, SpektralGraphDataset) + assert len(train) == 80 + assert len(test) == 20 + assert len(train) + len(test) == len(spektral_dataset) + + def test_pyg_split_test_train_validation(self, pyg_dataset): + """Test PyG dataset train/test/validation split""" + train, test, val = pyg_dataset.split_test_train_validation( + split_train=0.7, split_test=0.2, split_validation=0.1, random_seed=42 + ) + + assert isinstance(train, PyGGraphDataset) + assert isinstance(test, PyGGraphDataset) + assert isinstance(val, PyGGraphDataset) + assert len(train) == 70 + assert len(test) == 20 + assert len(val) == 10 + assert len(train) + len(test) + len(val) == len(pyg_dataset) + + @pytest.mark.spektral + def test_spektral_split_test_train_validation(self, spektral_dataset): + """Test Spektral dataset train/test/validation split""" + train, test, val = spektral_dataset.split_test_train_validation( + split_train=0.7, split_test=0.2, split_validation=0.1, random_seed=42 + ) + + assert isinstance(train, SpektralGraphDataset) + assert isinstance(test, SpektralGraphDataset) + assert isinstance(val, SpektralGraphDataset) + assert len(train) == 70 + assert len(test) == 20 + assert len(val) == 10 + assert len(train) + len(test) + len(val) == len(spektral_dataset) + + +class TestGraphDatasetEdgeCases: + """Test edge cases and error handling""" + + def test_empty_graphs_list_raises_error(self): + """Test that empty graphs list raises an error""" + with pytest.raises(ValueError): + GraphDataset(graphs=[]) + + def test_graphs_not_list_raises_error(self): + """Test that non-list graphs raises an error""" + with pytest.raises(ValueError): + GraphDataset(graphs="not a list") + + def test_no_input_raises_error(self): + """Test that no input raises an error""" + with pytest.raises(ValueError): + GraphDataset() + + def test_unknown_graph_type_raises_error(self): + """Test that unknown graph type raises an error""" + + class UnknownGraph: + pass + + unknown_graphs = [UnknownGraph() for _ in range(10)] + + with pytest.raises(ValueError): + GraphDataset(graphs=unknown_graphs) + + +class TestGraphDatasetRepr: + """Test string representation of datasets""" + + @pytest.mark.spektral + def test_spektral_repr(self): + """Test SpektralGraphDataset repr""" + from spektral.data import Graph + + graphs = [ + Graph( + x=np.random.randn(5, 3), + a=np.random.randint(0, 2, (5, 5)), + e=np.random.randn(10, 2), + y=np.array([0]), + id="graph_0", + ) + ] + + dataset = GraphDataset(graphs=graphs) + assert repr(dataset) == "SpektralGraphDataset(n_graphs=1)" + + def test_pyg_repr(self): + """Test PyGGraphDataset repr""" + import torch + from torch_geometric.data import Data + + graphs = [ + Data( + x=torch.randn(5, 3), + edge_index=torch.randint(0, 5, (2, 10)), + edge_attr=torch.randn(10, 2), + y=torch.tensor([0]), + ) + ] + + dataset = GraphDataset(graphs=graphs) + assert repr(dataset) == "PyGGraphDataset(n_graphs=1)" diff --git a/tests/test_soccer.py b/tests/test_soccer.py index c531b6e8..a7a3cdef 100644 --- a/tests/test_soccer.py +++ b/tests/test_soccer.py @@ -40,8 +40,6 @@ from kloppy.domain import Ground, TrackingDataset, Orientation from typing import List, Dict -from spektral.data import Graph - import pytest import numpy as np @@ -126,6 +124,7 @@ def kloppy_dataset(self, match_data: str, structured_data: str) -> TrackingDatas coordinates="tracab", include_empty_frames=False, limit=500, + only_alive=False, ) @pytest.fixture() @@ -789,7 +788,10 @@ def test_pi_full_include_ball_owning_speed_0( count = np.count_nonzero(np.isclose(arr, 0.0, atol=1e-5)) assert count == 117 + @pytest.mark.spektral def test_padding(self, spc_padding: SoccerGraphConverter): + from spektral.data import Graph + spektral_graphs = spc_padding.to_spektral_graphs() assert 1 == 1 @@ -801,6 +803,21 @@ def test_padding(self, spc_padding: SoccerGraphConverter): assert len(data) == 245 assert isinstance(data[0], Graph) + def test_padding(self, spc_padding: SoccerGraphConverter): + from torch_geometric.data import Data + + pyg_graphs = spc_padding.to_pyg_graphs() + + assert 1 == 1 + + data = pyg_graphs + for graph in data: + assert graph.num_nodes == 23 + + assert len(data) == 245 + assert isinstance(data[0], Data) + + @pytest.mark.spektral def test_object_ids(self, spc_padding: SoccerGraphConverter): spektral_graphs = spc_padding.to_spektral_graphs(include_object_ids=True) @@ -830,6 +847,35 @@ def test_object_ids(self, spc_padding: SoccerGraphConverter): "ball", ] + def test_object_ids_pyg(self, spc_padding: SoccerGraphConverter): + graphs = spc_padding.to_pyg_graphs(include_object_ids=True) + + assert graphs[10].object_ids == [ + None, # padded players + None, + None, + "10326", + "1138", + "11495", + "12788", + "5568", + "5585", + "6890", + "7207", + None, + None, + None, + "10308", + "1298", + "17902", + "2395", + "4812", + "5472", + "6158", + "9724", + "ball", + ] + def test_conversion(self, spc_padding: SoccerGraphConverter): results_df = spc_padding._convert() @@ -852,10 +898,14 @@ def test_conversion(self, spc_padding: SoccerGraphConverter): assert a0 == 23 assert a1 == 23 + @pytest.mark.spektral def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter): """ Test navigating (next/prev) through events """ + + from spektral.data import Graph + spektral_graphs = soccer_polars_converter.to_spektral_graphs() assert 1 == 1 @@ -933,6 +983,91 @@ def test_spektral_graph(self, soccer_polars_converter: SoccerGraphConverter): split_train=4, split_test=5, by_graph_id=True, random_seed=42 ) + def test_pyg_graph(self, soccer_polars_converter: SoccerGraphConverter): + """ + Test navigating (next/prev) through events + """ + + from torch_geometric.data import Data + + pyg_graphs = soccer_polars_converter.to_pyg_graphs() + + assert 1 == 1 + + data = pyg_graphs + assert data[0].id == "2417-1524" + assert len(data) == 383 + assert isinstance(data[0], Data) + + assert data[0].frame_id == 1524 + assert data[-1].frame_id == 2097 + + dataset = GraphDataset(graphs=pyg_graphs) + N, F, S, n_out, n = dataset.dimensions() + assert N == 20 + assert F == 15 + assert S == 6 + assert n_out == 1 + assert n == 383 + + train, test, val = dataset.split_test_train_validation( + split_train=4, + split_test=1, + split_validation=1, + by_graph_id=True, + random_seed=42, + ) + assert train.n_graphs == 255 + assert test.n_graphs == 63 + assert val.n_graphs == 65 + + train, test, val = dataset.split_test_train_validation( + split_train=4, + split_test=1, + split_validation=1, + by_graph_id=False, + random_seed=42, + ) + assert train.n_graphs == 255 + assert test.n_graphs == 63 + assert val.n_graphs == 65 + + train, test, val = dataset.split_test_train_validation( + split_train=4, + split_test=1, + split_validation=1, + by_graph_id=True, + random_seed=42, + test_label_ratio=(1 / 3), + train_label_ratio=(3 / 4), + val_label_ratio=(1 / 2), + ) + + assert train.n_graphs == 161 + assert test.n_graphs == 50 + assert val.n_graphs == 62 + + train, test = dataset.split_test_train( + split_train=4, split_test=1, by_graph_id=False, random_seed=42 + ) + assert train.n_graphs == 306 + assert test.n_graphs == 77 + + train, test = dataset.split_test_train( + split_train=4, split_test=5, by_graph_id=False, random_seed=42 + ) + assert train.n_graphs == 170 + assert test.n_graphs == 213 + + with pytest.raises( + NotImplementedError, + match="Make sure split_train > split_test >= split_validation, other behaviour is not supported when by_graph_id is True...", + ): + dataset.split_test_train( + split_train=4, split_test=5, by_graph_id=True, random_seed=42 + ) + + @pytest.mark.spektral def test_to_spektral_graph_level_features( self, soccer_polars_converter_graph_and_additional_features: SoccerGraphConverter, @@ -941,6 +1076,9 @@ def test_to_spektral_graph_level_features( """ Test navigating (next/prev) through events """ + + from spektral.data import Graph + soccer_polars_converter_graph_and_additional_features.settings.orientation = ( Orientation.STATIC_HOME_AWAY ) @@ -976,6 +1114,54 @@ def test_to_spektral_graph_level_features( assert not np.array_equal(data[0].x, data[5].x) assert not np.array_equal(data[0].e, data[5].e) + def test_to_spektral_graph_level_features( + self, + soccer_polars_converter_graph_and_additional_features: SoccerGraphConverter, + single_frame_node_feature_global_result_file: str, + ): + """ + Test navigating (next/prev) through events + """ + + from torch_geometric.data import Data + + soccer_polars_converter_graph_and_additional_features.settings.orientation = ( + Orientation.STATIC_HOME_AWAY + ) + + frame = soccer_polars_converter_graph_and_additional_features.dataset.filter( + pl.col("graph_id") == "2417-1529" + ) + + assert len(frame) == 15 + + pyg_graphs = ( + soccer_polars_converter_graph_and_additional_features.to_pyg_graphs() + ) + + assert 1 == 1 + + data = pyg_graphs + assert data[5].id == "2417-1529" + assert len(data) == 383 + assert isinstance(data[0], Data) + + x = data[5].x + + np.testing.assert_allclose( + x, np.load(single_frame_node_feature_global_result_file), rtol=1e-3 + ) + + e = data[5].edge_attr + assert e.shape == (129, 7) + assert e[:, 6][0] == 0.90 + + assert data[5].edge_index.shape == (2, 129) + + assert data[0] != data[5] + assert not np.array_equal(data[0].x, data[5].x) + assert not np.array_equal(data[0].edge_attr, data[5].edge_attr) + def test_line_method(self): positions = np.array([[1.0, 1.0], [2.0, 3.0], [0.5, 2.5], [4.0, 1.0]]) @@ -1010,6 +1196,7 @@ def test_line_method(self): assert np.array_equal(valid_mask, np.array([True, True, False, False])) + @pytest.mark.local_only def test_plot_graph(self, soccer_polars_converter: SoccerGraphConverter): plot_path = join("tests", "files", "plot", "test-1.mp4") @@ -1401,7 +1588,7 @@ def test_efpi_wrong(self, kloppy_polars_sportec_dataset): import pytest from polars.exceptions import PanicException - with pytest.raises(PanicException): + with pytest.raises(pl.exceptions.InvalidOperationError): model = EFPI(dataset=kloppy_polars_sportec_dataset) model.fit( formations=["442"], diff --git a/tests/test_spektral.py b/tests/test_spektral.py index 7c581793..f32fc408 100644 --- a/tests/test_spektral.py +++ b/tests/test_spektral.py @@ -2,20 +2,13 @@ from unravel.soccer import KloppyPolarsDataset, SoccerGraphConverter from unravel.american_football import BigDataBowlDataset, AmericanFootballGraphConverter from unravel.utils import dummy_labels, dummy_graph_ids, GraphDataset +from unravel.utils.objects.graph_dataset import SpektralGraphDataset from unravel.classifiers import CrystalGraphClassifier -from tensorflow.keras.models import load_model -from tensorflow.keras.losses import BinaryCrossentropy -from tensorflow.keras.optimizers import Adam -from tensorflow.keras.metrics import AUC, BinaryAccuracy - - from kloppy import skillcorner from kloppy.domain import TrackingDataset from typing import List, Dict -from spektral.data import DisjointLoader - import pytest import numpy as np @@ -68,6 +61,7 @@ def kloppy_dataset(self, match_data: str, structured_data: str) -> TrackingDatas coordinates="tracab", include_empty_frames=False, limit=100, + only_alive=False, ) @pytest.fixture() @@ -156,11 +150,18 @@ def bdb_converter_preds( verbose=False, ) + @pytest.mark.spektral def test_soccer_training(self, soccer_converter: SoccerGraphConverter): + from tensorflow.keras.models import load_model + from tensorflow.keras.losses import BinaryCrossentropy + from tensorflow.keras.optimizers import Adam + from tensorflow.keras.metrics import AUC, BinaryAccuracy + from spektral.data import DisjointLoader + train = GraphDataset(graphs=soccer_converter.to_spektral_graphs()) cd = soccer_converter.to_custom_dataset() - assert isinstance(cd, GraphDataset) + assert isinstance(cd, SpektralGraphDataset) pickle_folder = join("tests", "files", "kloppy") @@ -204,7 +205,11 @@ def test_soccer_training(self, soccer_converter: SoccerGraphConverter): assert np.allclose(pred, loaded_pred, atol=1e-8) + @pytest.mark.spektral def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter): + from tensorflow.keras.models import load_model + from spektral.data import DisjointLoader + pred_dataset = GraphDataset(graphs=soccer_converter_preds.to_spektral_graphs()) loader_pred = DisjointLoader( pred_dataset, batch_size=32, epochs=1, shuffle=False @@ -223,11 +228,18 @@ def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter): assert df["frame_id"].iloc[0] == "2417-1524" assert df["frame_id"].iloc[-1] == "2417-1621" + @pytest.mark.spektral def test_bdb_training(self, bdb_converter: AmericanFootballGraphConverter): + from tensorflow.keras.models import load_model + from tensorflow.keras.losses import BinaryCrossentropy + from tensorflow.keras.optimizers import Adam + from tensorflow.keras.metrics import AUC, BinaryAccuracy + from spektral.data import DisjointLoader + train = GraphDataset(graphs=bdb_converter.to_spektral_graphs()) cd = bdb_converter.to_custom_dataset() - assert isinstance(cd, GraphDataset) + assert isinstance(cd, SpektralGraphDataset) pickle_folder = join("tests", "files", "bdb") @@ -272,7 +284,11 @@ def test_bdb_training(self, bdb_converter: AmericanFootballGraphConverter): assert np.allclose(pred, loaded_pred, atol=1e-8) + @pytest.mark.spektral def test_dbd_prediction(self, bdb_converter_preds: AmericanFootballGraphConverter): + from tensorflow.keras.models import load_model + from spektral.data import DisjointLoader + pred_dataset = GraphDataset(graphs=bdb_converter_preds.to_spektral_graphs()) loader_pred = DisjointLoader( pred_dataset, batch_size=32, epochs=1, shuffle=False @@ -287,7 +303,7 @@ def test_dbd_prediction(self, bdb_converter_preds: AmericanFootballGraphConverte df = pd.DataFrame( {"frame_id": [x.id for x in pred_dataset], "y": preds.flatten()} - ) + ).sort_values("frame_id") - assert df["frame_id"].iloc[0] == "2021092612-54" - assert df["frame_id"].iloc[-1] == "2021092609-54" + assert df["frame_id"].iloc[0] == "2021091300-4845" + assert df["frame_id"].iloc[-1] == "2021103108-54" diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..5969ce29 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,196 @@ +from pathlib import Path +from unravel.soccer import KloppyPolarsDataset, SoccerGraphConverter +from unravel.utils import dummy_labels, dummy_graph_ids, GraphDataset +from unravel.classifiers import PyGLightningCrystalGraphClassifier + +import torch +import pytorch_lightning as pyl +from torch_geometric.loader import DataLoader + +from kloppy import skillcorner +from kloppy.domain import TrackingDataset +from typing import List, Dict + +import pytest + +import numpy as np +import pandas as pd + +from os.path import join + + +class TestPyTorchGeometric: + @pytest.fixture + def match_data(self, base_dir: Path) -> str: + return base_dir / "files" / "skillcorner_match_data.json" + + @pytest.fixture + def structured_data(self, base_dir: Path) -> str: + return base_dir / "files" / "skillcorner_structured_data.json.gz" + + @pytest.fixture() + def kloppy_dataset(self, match_data: str, structured_data: str) -> TrackingDataset: + return skillcorner.load( + raw_data=structured_data, + meta_data=match_data, + coordinates="tracab", + include_empty_frames=False, + limit=100, + only_alive=False, + ) + + @pytest.fixture() + def kloppy_polars_dataset( + self, kloppy_dataset: TrackingDataset + ) -> KloppyPolarsDataset: + dataset = KloppyPolarsDataset( + kloppy_dataset=kloppy_dataset, + ball_carrier_threshold=25.0, + max_player_speed=12.0, + max_player_acceleration=12.0, + max_ball_speed=13.5, + max_ball_acceleration=100, + ) + dataset.add_dummy_labels(by=["game_id", "frame_id"], random_seed=42) + dataset.add_graph_ids(by=["game_id", "frame_id"]) + return dataset + + @pytest.fixture() + def soccer_converter( + self, kloppy_polars_dataset: KloppyPolarsDataset + ) -> SoccerGraphConverter: + return SoccerGraphConverter( + dataset=kloppy_polars_dataset, + chunk_size=2_0000, + non_potential_receiver_node_value=0.1, + self_loop_ball=True, + adjacency_matrix_connect_type="ball", + adjacency_matrix_type="split_by_team", + label_type="binary", + defending_team_node_value=0.0, + random_seed=42, + pad=True, + verbose=False, + ) + + @pytest.fixture() + def soccer_converter_preds( + self, kloppy_polars_dataset: KloppyPolarsDataset + ) -> SoccerGraphConverter: + return SoccerGraphConverter( + dataset=kloppy_polars_dataset, + prediction=True, + chunk_size=2_0000, + non_potential_receiver_node_value=0.1, + self_loop_ball=True, + adjacency_matrix_connect_type="ball", + adjacency_matrix_type="split_by_team", + label_type="binary", + defending_team_node_value=0.0, + random_seed=42, + pad=True, + verbose=False, + ) + + def test_soccer_training(self, soccer_converter: SoccerGraphConverter): + # Convert to PyTorch Geometric graphs + pyg_graphs = soccer_converter.to_pytorch_graphs() + train = GraphDataset(graphs=pyg_graphs, format="pyg") + + pickle_folder = join("tests", "files", "kloppy") + + soccer_converter.to_pickle(join(pickle_folder, "test.pickle.gz")) + + with pytest.raises( + ValueError, + match="Only compressed pickle files of type 'some_file_name.pickle.gz' are supported...", + ): + soccer_converter.to_pickle(join(pickle_folder, "test.pickle")) + + # Initialize PyTorch Lightning model + model = PyGLightningCrystalGraphClassifier( + n_layers=3, channels=128, drop_out=0.5, n_out=1 + ) + + assert model.model.n_layers == 3 + assert model.model.channels == 128 + assert model.model.drop_out == 0.5 + assert model.model.n_out == 1 + + # Create DataLoader for training and validation + loader_tr = DataLoader(train, batch_size=32, shuffle=True) + loader_val = DataLoader(train, batch_size=32, shuffle=False) + + # Initialize trainer + trainer = pyl.Trainer( + max_epochs=1, + accelerator="auto", + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + ) + + # Train model with validation data + trainer.fit(model, loader_tr, loader_val) + + # Save model using the same path as in jupyter notebook + model_path = join("tests", "files", "models", "my-first-graph-classifier.ckpt") + trainer.save_checkpoint(model_path) + + # Load model + loaded_model = PyGLightningCrystalGraphClassifier.load_from_checkpoint( + model_path + ) + + # Create test loader + loader_te = DataLoader(train, batch_size=32, shuffle=False) + + # Make predictions with original model + trainer_pred = pyl.Trainer( + accelerator="auto", logger=False, enable_progress_bar=False + ) + pred = trainer_pred.predict(model, loader_te) + pred = torch.cat(pred).cpu().numpy() + + # Make predictions with loaded model + loader_te = DataLoader(train, batch_size=32, shuffle=False) + loaded_pred = trainer_pred.predict(loaded_model, loader_te) + loaded_pred = torch.cat(loaded_pred).cpu().numpy() + + assert np.allclose(pred, loaded_pred, atol=1e-6) + + def test_soccer_prediction(self, soccer_converter_preds: SoccerGraphConverter): + # Convert to PyTorch Geometric graphs + pyg_graphs = soccer_converter_preds.to_pytorch_graphs() + pred_dataset = GraphDataset(graphs=pyg_graphs, format="pyg") + + # Create DataLoader + loader_pred = DataLoader(pred_dataset, batch_size=32, shuffle=False) + + # Load model using the same path as in jupyter notebook + model_path = join("tests", "files", "models", "my-first-graph-classifier.ckpt") + loaded_model = PyGLightningCrystalGraphClassifier.load_from_checkpoint( + model_path + ) + + # Make predictions + trainer = pyl.Trainer( + accelerator="auto", logger=False, enable_progress_bar=False + ) + preds = trainer.predict(loaded_model, loader_pred) + preds = torch.cat(preds).cpu().numpy() + + assert not np.any(np.isnan(preds.flatten())) + + # Get graph IDs from the dataset + graph_ids = [] + for i in range(len(pred_dataset)): + graph = pred_dataset.graphs[i] + graph_ids.append(graph.id) + + df = pd.DataFrame({"frame_id": graph_ids, "y": preds.flatten()}).sort_values( + by=["frame_id"] + ) + + assert df["frame_id"].iloc[0] == "2417-1524" + assert df["frame_id"].iloc[-1] == "2417-1621" diff --git a/unravel/__init__.py b/unravel/__init__.py index afbef27d..10dfa3cb 100644 --- a/unravel/__init__.py +++ b/unravel/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.0" +__version__ = "2.0.0" from .soccer import * from .american_football import * diff --git a/unravel/american_football/dataset/dataset.py b/unravel/american_football/dataset/dataset.py index bfe5e36e..699f0e93 100644 --- a/unravel/american_football/dataset/dataset.py +++ b/unravel/american_football/dataset/dataset.py @@ -219,7 +219,9 @@ def load(self): ] ).drop(["frameId"]) - self.data = df + self.data = df.sort( + [Column.GAME_ID, Column.PLAY_ID, Column.FRAME_ID, Column.OBJECT_ID] + ) # update pitch dimensions to how it looks after loading self.settings.pitch_dimensions = AmericanFootballPitchDimensions( diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index 21d06e63..94b26cf8 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -7,8 +7,6 @@ from typing import List, Optional -from spektral.data import Graph - from ..dataset import BigDataBowlDataset, Group, Column, Constant from .graph_settings import ( @@ -97,7 +95,7 @@ def _sample(self): def _sport_specific_checks(self): def __remove_with_missing_values(min_object_count: int = 10): cs = ( - self.dataset.group_by(Group.BY_FRAME) + self.dataset.group_by(Group.BY_FRAME, maintain_order=True) .agg(pl.len().alias("size")) .filter( pl.col("size") < min_object_count @@ -113,7 +111,7 @@ def __remove_with_missing_values(min_object_count: int = 10): def __remove_with_missing_football(): cs = ( - self.dataset.group_by(Group.BY_FRAME) + self.dataset.group_by(Group.BY_FRAME, maintain_order=True) .agg( [ pl.len().alias("size"), # Count total rows in each group @@ -269,15 +267,9 @@ def _compute(self, args: List[pl.Series]) -> dict: settings=self.settings, ) return { - "e": pl.Series( - [edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "x": pl.Series( - [node_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "a": pl.Series( - [adjacency_matrix.tolist()], dtype=pl.List(pl.List(pl.Int32)) - ), + "e": edge_features.tolist(), # Remove pl.Series wrapper + "x": node_features.tolist(), # Remove pl.Series wrapper + "a": adjacency_matrix.tolist(), # Remove pl.Series wrapper "e_shape_0": edge_features.shape[0], "e_shape_1": edge_features.shape[1], "x_shape_0": node_features.shape[0], @@ -289,6 +281,25 @@ def _compute(self, args: List[pl.Series]) -> dict: "frame_id": frame_id, } + @property + def return_dtypes(self): + return pl.Struct( + { + "e": pl.List(pl.List(pl.Float64)), + "x": pl.List(pl.List(pl.Float64)), + "a": pl.List(pl.List(pl.Int32)), + "e_shape_0": pl.Int64, + "e_shape_1": pl.Int64, + "x_shape_0": pl.Int64, + "x_shape_1": pl.Int64, + "a_shape_0": pl.Int64, + "a_shape_1": pl.Int64, + self.graph_id_column: pl.String, + self.label_column: pl.Int64, + "frame_id": pl.Int64, + } + ) + def _convert(self): # Group and aggregate in one step return ( @@ -298,6 +309,7 @@ def _convert(self): exprs=self._exprs_variables + [Column.FRAME_ID], function=self._compute, return_dtype=self.return_dtypes, + returns_scalar=True, ).alias("result_dict") ) .with_columns( diff --git a/unravel/classifiers/__init__.py b/unravel/classifiers/__init__.py index 070600b2..5cd34fbc 100644 --- a/unravel/classifiers/__init__.py +++ b/unravel/classifiers/__init__.py @@ -1 +1,19 @@ -from .crystal_graph import CrystalGraphClassifier +try: + from .crystal_graph import CrystalGraphClassifier + + __all__ = ["CrystalGraphClassifier"] +except ImportError: + + class CrystalGraphClassifier: + def __init__(self, *args, **kwargs): + raise ImportError( + "CrystalGraphClassifier requires spektral (Python 3.11 only). " + "Install with: pip install spektral==1.2.0 keras==2.14.0 tensorflow>=2.14.0" + ) + + __all__ = ["CrystalGraphClassifier"] + + +from .crystal_graph_pyg import PyGLightningCrystalGraphClassifier + +__all__.append("PyGLightningCrystalGraphClassifier") diff --git a/unravel/classifiers/crystal_graph_pyg.py b/unravel/classifiers/crystal_graph_pyg.py new file mode 100644 index 00000000..f6e144dc --- /dev/null +++ b/unravel/classifiers/crystal_graph_pyg.py @@ -0,0 +1,207 @@ +try: + import torch + import torch.nn as nn + import torch.nn.functional as F + from torch_geometric.nn import CGConv, global_mean_pool + + _HAS_TORCH_GEOMETRIC = True + _BASE_CLASS = nn.Module +except ImportError: + _HAS_TORCH_GEOMETRIC = False + _BASE_CLASS = object + +try: + import pytorch_lightning as pyl + from torchmetrics import AUROC, Accuracy + + _HAS_PYTORCH_LIGHTNING = True + _PYL_BASE_CLASS = pyl.LightningModule +except ImportError: + _HAS_PYTORCH_LIGHTNING = False + _PYL_BASE_CLASS = object + + +class PyGCrystalGraphClassifier(_BASE_CLASS): + """ + Graph Classifier with CGConv using edge features. + """ + + def __init__( + self, + n_layers: int = 3, + channels: int = 128, + drop_out: float = 0.5, + n_out: int = 1, + **kwargs + ): + if not _HAS_TORCH_GEOMETRIC: + raise ImportError( + "PyTorch Geometric is required for PyGCrystalGraphClassifier. " + "Install it using: pip install torch torch-geometric pytorch-lightning torchmetrics" + ) + super().__init__() + + self.n_layers = n_layers + self.channels = channels + self.drop_out = drop_out + self.n_out = n_out + + # Project variable node features to fixed size + self.input_projection = nn.LazyLinear(channels) + + # Project variable edge features to fixed size + self.edge_projection = nn.LazyLinear(channels) + + # CGConv layers with edge features + # dim should be the edge feature dimension AFTER projection + self.convs = nn.ModuleList( + [ + CGConv( + channels, dim=channels + ) # Edge features have 'channels' dimensions after projection + for _ in range(self.n_layers) + ] + ) + + # Dense layers + self.dense1 = nn.Linear(channels, channels) + self.dropout = nn.Dropout(drop_out) + self.dense2 = nn.Linear(channels, channels) + self.dense3 = nn.Linear(channels, n_out) + + def forward(self, x, edge_index, edge_attr=None, batch=None): + """ + Args: + x: Node features [num_nodes, in_channels] + edge_index: Edge indices [2, num_edges] + edge_attr: Edge features [num_edges, edge_features] + batch: Batch vector [num_nodes] + + Returns: + out: Graph-level predictions [batch_size, n_out] + """ + # Project node features to fixed size + x = self.input_projection(x) + + # Project edge features to fixed size (if they exist) + if edge_attr is not None: + edge_attr = self.edge_projection(edge_attr) + + # Apply CGConv layers + for conv in self.convs: + x = conv(x, edge_index, edge_attr) + + # Global pooling + x = global_mean_pool(x, batch) + + # Dense layers with dropout + x = F.relu(self.dense1(x)) + x = self.dropout(x) + x = F.relu(self.dense2(x)) + x = self.dropout(x) + x = torch.sigmoid(self.dense3(x)) + + return x + + +class PyGLightningCrystalGraphClassifier(_PYL_BASE_CLASS): + def __init__( + self, + n_layers=3, + channels=128, + drop_out=0.5, + n_out=1, + lr=0.001, + weight_decay=0.0, + ): + if not _HAS_PYTORCH_LIGHTNING: + raise ImportError( + "PyTorch Lightning is required for PyGLightningCrystalGraphClassifier. " + "Install it using: pip install pytorch-lightning torchmetrics" + ) + super().__init__() + self.save_hyperparameters() + + self.model = PyGCrystalGraphClassifier( + n_layers=n_layers, channels=channels, drop_out=drop_out, n_out=n_out + ) + self.criterion = torch.nn.BCELoss() + + # Training metrics + self.train_auc = AUROC(task="binary") + self.train_acc = Accuracy(task="binary") + + # Validation metrics + self.val_auc = AUROC(task="binary") + self.val_acc = Accuracy(task="binary") + + # Test metrics (ADD THESE!) + self.test_auc = AUROC(task="binary") + self.test_acc = Accuracy(task="binary") + + def forward(self, x, edge_index, edge_attr, batch): + return self.model(x, edge_index, edge_attr, batch).squeeze(-1) + + def training_step(self, batch, batch_idx): + out = self(batch.x, batch.edge_index, batch.edge_attr, batch.batch) + loss = self.criterion(out, batch.y.float()) + + self.train_auc(out, batch.y.int()) + self.train_acc(out, batch.y.int()) + self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log( + "train_auc", self.train_auc, on_step=False, on_epoch=True, prog_bar=True + ) + self.log( + "train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True + ) + + return loss + + def validation_step(self, batch, batch_idx): + out = self(batch.x, batch.edge_index, batch.edge_attr, batch.batch) + loss = self.criterion(out, batch.y.float()) + + self.val_auc(out, batch.y.int()) + self.val_acc(out, batch.y.int()) + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("val_auc", self.val_auc, on_step=False, on_epoch=True, prog_bar=True) + self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + + return loss + + def test_step(self, batch, batch_idx): + """Test step for evaluation""" + out = self(batch.x, batch.edge_index, batch.edge_attr, batch.batch) + loss = self.criterion(out, batch.y.float()) + + # Use the class-level test metrics + self.test_auc(out, batch.y.int()) + self.test_acc(out, batch.y.int()) + + self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("test_auc", self.test_auc, on_step=False, on_epoch=True, prog_bar=True) + self.log("test_acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + + return loss + + def predict_step(self, batch, batch_idx): + """Prediction step - returns probabilities""" + out = self(batch.x, batch.edge_index, batch.edge_attr, batch.batch) + return out + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.hparams.lr, + weight_decay=self.hparams.weight_decay, + ) + + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.5, patience=3 + ) + + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}, + } diff --git a/unravel/soccer/dataset/kloppy_polars.py b/unravel/soccer/dataset/kloppy_polars.py index 76f7b73c..79512dbf 100644 --- a/unravel/soccer/dataset/kloppy_polars.py +++ b/unravel/soccer/dataset/kloppy_polars.py @@ -229,36 +229,40 @@ def __apply_smoothing(self, df: pl.DataFrame, smoothing_params: dict): vy_smooth = f"{Column.VY}_smoothed" vz_smooth = f"{Column.VZ}_smoothed" + # DEBUG: Check group sizes + group_sizes = df.group_by(Group.BY_OBJECT_PERIOD).agg( + pl.col(Column.VX).count().alias("count") + ) + + window_length = smoothing_params["window_length"] + polyorder = smoothing_params["polyorder"] + + def apply_savgol(series): + """Apply savgol filter to a series (array of values).""" + values = series.to_numpy() + if len(values) < window_length: + return values.tolist() + return savgol_filter( + values, + window_length=window_length, + polyorder=polyorder, + ).tolist() + smoothed = df.group_by(Group.BY_OBJECT_PERIOD, maintain_order=True).agg( [ pl.col(Column.VX) - .map_elements( - lambda vx: savgol_filter( - vx, - window_length=smoothing_params["window_length"], - polyorder=smoothing_params["polyorder"], - ).tolist(), - return_dtype=pl.List(pl.Float64), + .map_batches( + apply_savgol, return_dtype=pl.List(pl.Float64), returns_scalar=True ) .alias(vx_smooth), pl.col(Column.VY) - .map_elements( - lambda vy: savgol_filter( - vy, - window_length=smoothing_params["window_length"], - polyorder=smoothing_params["polyorder"], - ).tolist(), - return_dtype=pl.List(pl.Float64), + .map_batches( + apply_savgol, return_dtype=pl.List(pl.Float64), returns_scalar=True ) .alias(vy_smooth), pl.col(Column.VZ) - .map_elements( - lambda vy: savgol_filter( - vy, - window_length=smoothing_params["window_length"], - polyorder=smoothing_params["polyorder"], - ).tolist(), - return_dtype=pl.List(pl.Float64), + .map_batches( + apply_savgol, return_dtype=pl.List(pl.Float64), returns_scalar=True ) .alias(vz_smooth), ] @@ -472,7 +476,7 @@ def __infer_ball_carrier(self, df: pl.DataFrame): ) # Update ball_owning_team if necessary ball_owning_team = (players_ball.drop(Column.BALL_OWNING_TEAM_ID)).join( - players_ball.group_by(Group.BY_FRAME) + players_ball.group_by(Group.BY_FRAME, maintain_order=True) .agg( [ pl.when((pl.col(Column.BALL_OWNING_TEAM_ID).is_null())) @@ -508,7 +512,7 @@ def __infer_ball_carrier(self, df: pl.DataFrame): ball_owning_team.filter( (pl.col(Column.BALL_OWNING_TEAM_ID) == pl.col(Column.TEAM_ID)) ) - .group_by(Group.BY_FRAME) + .group_by(Group.BY_FRAME, maintain_order=True) .agg( [ pl.when((pl.col(Column.BALL_OWNING_PLAYER_ID).is_null())) @@ -705,7 +709,7 @@ def load( df = df.filter(~(pl.col(Column.X).is_null() & pl.col(Column.Y).is_null())) if df[Column.BALL_OWNING_TEAM_ID].is_null().all(): - if self.ball_carrier_threshold is None: + if self._ball_carrier_threshold is None: raise ValueError( f"This dataset requires us to infer the {Column.BALL_OWNING_TEAM_ID}, please specifiy a ball_carrier_threshold (float) to do so." ) diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index bd897804..94a93f7f 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -3,7 +3,7 @@ from dataclasses import dataclass -from typing import List, Union, Dict, Literal, Any, Optional, Callable +from typing import List, Union, Dict, Literal, Any, Optional, Callable, TYPE_CHECKING import inspect @@ -11,7 +11,8 @@ from kloppy.domain import MetricPitchDimensions, Orientation -from spektral.data import Graph +if TYPE_CHECKING: + from spektral.data import Graph from .graph_settings import GraphSettingsPolars from ..dataset.kloppy_polars import KloppyPolarsDataset, Column, Group, Constant @@ -147,7 +148,7 @@ def _remove_incomplete_frames(self) -> pl.DataFrame: total_frames = len(df.unique(Group.BY_FRAME)) valid_frames = ( - df.group_by(Group.BY_FRAME) + df.group_by(Group.BY_FRAME, maintain_order=True) .agg(pl.col(Column.TEAM_ID).n_unique().alias("unique_teams")) .filter(pl.col("unique_teams") == 3) .select(Group.BY_FRAME) @@ -201,7 +202,7 @@ def _apply_padding(self) -> pl.DataFrame: + self.global_feature_cols ] - counts = df.group_by(group_by_columns).agg( + counts = df.group_by(group_by_columns, maintain_order=True).agg( pl.len().alias("count"), *[ pl.first(col).alias(col) @@ -322,7 +323,7 @@ def _apply_padding(self) -> pl.DataFrame: total_frames = result.select(Group.BY_FRAME).unique().height frame_completeness = ( - result.group_by(Group.BY_FRAME) + result.group_by(Group.BY_FRAME, maintain_order=True) .agg( [ (pl.col(Column.TEAM_ID).eq(Constant.BALL).sum() == 1).alias( @@ -588,15 +589,9 @@ def _compute(self, args: List[pl.Series]) -> dict: ) return { - "e": pl.Series( - [edge_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "x": pl.Series( - [node_features.tolist()], dtype=pl.List(pl.List(pl.Float64)) - ), - "a": pl.Series( - [adjacency_matrix.tolist()], dtype=pl.List(pl.List(pl.Int32)) - ), + "e": edge_features.tolist(), # Remove pl.Series wrapper + "x": node_features.tolist(), # Remove pl.Series wrapper + "a": adjacency_matrix.tolist(), # Remove pl.Series wrapper "e_shape_0": edge_features.shape[0], "e_shape_1": edge_features.shape[1], "x_shape_0": node_features.shape[0], @@ -606,9 +601,9 @@ def _compute(self, args: List[pl.Series]) -> dict: self.graph_id_column: frame_data[self.graph_id_column][0], self.label_column: frame_data[self.label_column][0], "frame_id": frame_id, - "object_ids": pl.Series( - [frame_data[Column.OBJECT_ID].tolist()], dtype=pl.List(pl.String) - ), + "object_ids": frame_data[ + Column.OBJECT_ID + ].tolist(), # Remove pl.Series wrapper "ball_owning_team_id": ball_owning_team_id, } @@ -621,6 +616,7 @@ def _convert(self): exprs=self._exprs_variables + [Column.FRAME_ID], function=self._compute, return_dtype=self.return_dtypes, + returns_scalar=True, ).alias("result_dict") ) .with_columns( @@ -899,7 +895,7 @@ def plot_graph(): """Plot graph features (node features, adjacency matrix, edge features)""" import matplotlib.pyplot as plt - num_rows = self._graph.x.shape[0] + num_rows = self._graph["x"].shape[0] labels = ( [ @@ -908,7 +904,7 @@ def plot_graph(): if pid != Constant.BALL else Constant.BALL ) - for pid in self._graph.object_ids + for pid in self._graph["object_ids"] ] if not anonymous else [str(i) for i in range(num_rows)] @@ -926,8 +922,8 @@ def plot_graph(): # Plot node features ax1 = self._fig.add_subplot(self._gs[node_pos]) - ax1.imshow(self._graph.x, aspect="auto", cmap="YlOrRd") - ax1.set_xlabel(f"Node Features {self._graph.x.shape}") + ax1.imshow(self._graph["x"], aspect="auto", cmap="YlOrRd") + ax1.set_xlabel(f"Node Features {self._graph['x'].shape}") # Set y labels to integers ax1.set_yticks(range(num_rows)) @@ -940,12 +936,12 @@ def plot_graph(): # Plot adjacency matrix ax2 = self._fig.add_subplot(self._gs[adj_pos]) - ax2.imshow(self._graph.a.toarray(), aspect="auto", cmap="YlOrRd") - ax2.set_xlabel(f"Adjacency Matrix {self._graph.a.shape}") + ax2.imshow(self._graph["a"].toarray(), aspect="auto", cmap="YlOrRd") + ax2.set_xlabel(f"Adjacency Matrix {self._graph['a'].shape}") # Set both x and y labels to integers - num_rows_a = self._graph.a.toarray().shape[0] - num_cols_a = self._graph.a.toarray().shape[1] + num_rows_a = self._graph["a"].toarray().shape[0] + num_cols_a = self._graph["a"].toarray().shape[1] ax2.set_yticks(range(num_rows_a)) ax2.set_yticklabels(labels) @@ -956,15 +952,19 @@ def plot_graph(): # Plot Edge Features ax3 = self._fig.add_subplot(self._gs[edge_pos]) - _, size_a = non_zeros(self._graph.a.toarray()[0 : self._ball_carrier_idx]) + _, size_a = non_zeros( + self._graph["a"].toarray()[0 : self._ball_carrier_idx] + ) ball_carrier_edge_idx, num_rows_e = non_zeros( np.asarray( - [list(x) for x in self._graph.a.toarray()][self._ball_carrier_idx] + [list(x) for x in self._graph["a"].toarray()][ + self._ball_carrier_idx + ] ) ) im3 = ax3.imshow( - self._graph.e[size_a : num_rows_e + size_a, :], + self._graph["e"][size_a : num_rows_e + size_a, :], aspect="auto", cmap="YlOrRd", ) @@ -972,7 +972,7 @@ def plot_graph(): ax3.set_yticks(range(num_rows_e)) ax3.set_yticklabels(list(ball_carrier_edge_idx[0]), fontsize=18) ball_carrier_edge_idxs = list(ball_carrier_edge_idx[0]) - ax3.set_xlabel(f"Edge Features {self._graph.e.shape}") + ax3.set_xlabel(f"Edge Features {self._graph['e'].shape}") ax3_labels = ax3.get_yticklabels() if self._ball_carrier_idx in ball_carrier_edge_idx[0]: @@ -1144,6 +1144,7 @@ def player_and_ball(frame_data, ax): ax.set_title(self._gameclock, fontsize=22) def frame_plot(self, frame_data): + def timestamp_to_gameclock(timestamp, period_id): total_seconds = timestamp.total_seconds() @@ -1175,15 +1176,15 @@ def timestamp_to_gameclock(timestamp, period_id): ) y = np.asarray([features[self.label_column]]) - self._graph = Graph( - a=a, - x=x, - e=e, - y=y, - frame_id=features["frame_id"], - object_ids=frame_data[Column.OBJECT_ID], - ball_owning_team_id=frame_data[Column.BALL_OWNING_TEAM_ID][0], - ) + self._graph = { + "a": a, + "x": x, + "e": e, + "y": y, + "frame_id": features["frame_id"], + "object_ids": frame_data[Column.OBJECT_ID], + "ball_owning_team_id": frame_data[Column.BALL_OWNING_TEAM_ID][0], + } self._ball_carrier_idx = np.where( frame_data[Column.IS_BALL_CARRIER] == True diff --git a/unravel/soccer/models/formations/efpi.py b/unravel/soccer/models/formations/efpi.py index eb02fed4..5e5c5242 100644 --- a/unravel/soccer/models/formations/efpi.py +++ b/unravel/soccer/models/formations/efpi.py @@ -220,6 +220,17 @@ def _compute(self, args: List[pl.Series], **kwargs) -> pl.DataFrame: object_ids=d[Column.OBJECT_ID].tolist(), team_ids=d[Column.TEAM_ID].tolist() ) + @property + def return_dtypes(self): + return pl.Struct( + { + Column.OBJECT_ID: pl.List(pl.String), + Column.TEAM_ID: pl.List(pl.String), + "position": pl.List(pl.String), + "formation": pl.List(pl.String), + } + ) + def fit( self, start_time: pl.duration = None, @@ -274,7 +285,8 @@ def fit( pl.map_groups( exprs=self._exprs_variables, function=lambda group: self._compute(group), - return_dtype=pl.Struct, + return_dtype=self.return_dtypes, + returns_scalar=True, ).alias("result") ) .unnest("result") @@ -362,9 +374,12 @@ def fit( & (pl.col(Column.POSITION_NAME) != "GK") ) .group_by( - [Column.GAME_ID, Column.PERIOD_ID, Column.TEAM_ID, segment_id] - if self._every != "period" - else [Column.GAME_ID, Column.PERIOD_ID, Column.TEAM_ID] + ( + [Column.GAME_ID, Column.PERIOD_ID, Column.TEAM_ID, segment_id] + if self._every != "period" + else [Column.GAME_ID, Column.PERIOD_ID, Column.TEAM_ID] + ), + maintain_order=True, ) .agg([pl.col(Column.OBJECT_ID).n_unique().alias("objects")]) .sort([segment_id]) @@ -391,7 +406,7 @@ def fit( segment_id, ], ) - .group_by(columns) + .group_by(columns, maintain_order=True) .agg([pl.len().alias("length")]) .with_columns( pl.col("length") @@ -422,9 +437,12 @@ def fit( segment_coordinates = ( df1.group_by( - group_by_columns + [segment_id] - if self._every != "period" - else group_by_columns + ( + group_by_columns + [segment_id] + if self._every != "period" + else group_by_columns + ), + maintain_order=True, ) .agg( [ @@ -467,7 +485,8 @@ def fit( pl.map_groups( exprs=self._exprs_variables, function=lambda group: self._compute(group), - return_dtype=pl.Struct, + return_dtype=self.return_dtypes, + returns_scalar=True, ).alias("result") ) .unnest("result") diff --git a/unravel/soccer/models/pressing_intensity.py b/unravel/soccer/models/pressing_intensity.py index a98c9382..bd5998df 100644 --- a/unravel/soccer/models/pressing_intensity.py +++ b/unravel/soccer/models/pressing_intensity.py @@ -235,6 +235,17 @@ def _set_minimum(matrix, ball_carrier_idx, ball_idx): "rows": row_objects.tolist(), } + @property + def __get_return_dtype(self): + return pl.Struct( + { + "time_to_intercept": pl.List(pl.List(pl.Float64)), + "probability_to_intercept": pl.List(pl.List(pl.Float64)), + "columns": pl.List(pl.String), + "rows": pl.List(pl.String), + } + ) + def fit( self, start_time: pl.duration = None, @@ -354,6 +365,8 @@ def fit( pl.map_groups( exprs=self.__exprs_variables, function=self.__compute, + return_dtype=self.__get_return_dtype, + returns_scalar=True, ).alias("results") ) .unnest("results") diff --git a/unravel/utils/exceptions/exceptions.py b/unravel/utils/exceptions/exceptions.py index cb5d134d..8709ea2d 100644 --- a/unravel/utils/exceptions/exceptions.py +++ b/unravel/utils/exceptions/exceptions.py @@ -28,3 +28,20 @@ class AdjcacenyMatrixTypeNotSetException(Exception): class KeyMismatchException(Exception): pass + + +class SpektralDependencyError(ImportError): + """Raised when Spektral or its dependencies are not properly installed.""" + + def __init__(self): + self.message = ( + "Seems like you don't have spektral installed.\n\n" + "Requirements:\n" + " - Python 3.11 (recommended)\n\n" + "Installation:\n" + " pip install spektral==1.2.0 keras==2.14.0 && " + "(pip install tensorflow>=2.14.0 || pip install tensorflow-macos>=2.14.0)" + "\nWarning:\n" + " If you want to use Spektral, it is advised to use unravelsports v1.1.0 or below and Python3.11. Or, simply continue using PyTorch functionality instead." + ) + super().__init__(self.message) diff --git a/unravel/utils/features/utils.py b/unravel/utils/features/utils.py index 9a6441d1..46e8fd83 100644 --- a/unravel/utils/features/utils.py +++ b/unravel/utils/features/utils.py @@ -211,6 +211,12 @@ def flatten_to_reshaped_array(arr, s0, s1, as_list=False): return result_array if not as_list else result_array.tolist() +def flatten_to_reshaped_array(arr, s0, s1, as_list=False): + # Convert to numpy array directly + result_array = np.array(arr).reshape(s0, s1) + return result_array if not as_list else result_array.tolist() + + def reshape_array(arr): return np.array([a for a in arr.to_numpy()]) diff --git a/unravel/utils/objects/default_graph_converter.py b/unravel/utils/objects/default_graph_converter.py index ca854f91..ec002721 100644 --- a/unravel/utils/objects/default_graph_converter.py +++ b/unravel/utils/objects/default_graph_converter.py @@ -6,14 +6,16 @@ import polars as pl -from typing import List, Union, Dict, Literal +from typing import List, Union, Dict, Literal, TYPE_CHECKING from kloppy.domain import TrackingDataset -from spektral.data import Graph +if TYPE_CHECKING: + from spektral.data import Graph + from torch_geometric.data import Data from ..exceptions import ( - KeyMismatchException, + SpektralDependencyError, ) from ..features import ( AdjacencyMatrixType, @@ -67,6 +69,7 @@ class DefaultGraphConverter: verbose (bool): The converter logs warnings / error messages when specific frames have no coordinates, or other missing information. False mutes all these warnings. """ + engine: Literal["auto", "gpu"] = "auto" prediction: bool = False self_loop_ball: bool = False @@ -169,7 +172,64 @@ def _apply_graph_settings(self): def _convert(self): raise NotImplementedError() - def to_spektral_graphs(self, include_object_ids: bool = False) -> List[Graph]: + def to_pytorch_graphs( + self, include_object_ids: bool = False + ) -> List["torch_geometric.data.Data"]: + """ + Convert graph frames to PyTorch Geometric Data objects. + + Returns: + List of torch_geometric.data.Data objects + """ + try: + import torch + from torch_geometric.data import Data + except ImportError: + raise ImportError( + "PyTorch Geometric is required for this functionality. " + "Install it with: pip install torch torch-geometric" + ) + + if not self.graph_frames: + self.to_graph_frames(include_object_ids) + + pyg_graphs = [] + for d in self.graph_frames: + x = torch.tensor(d["x"], dtype=torch.float) + + a = d["a"].toarray() if hasattr(d["a"], "toarray") else d["a"] + + edge_indices = np.nonzero(a) + edge_index = torch.tensor(np.vstack(edge_indices), dtype=torch.long) + + edge_attr = torch.tensor(d["e"], dtype=torch.float) + + y = torch.tensor(d["y"], dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + + data.id = d["id"] + data.frame_id = d["frame_id"] + data.ball_owning_team_id = d.get("ball_owning_team_id", None) + + if include_object_ids: + data.object_ids = d["object_ids"] + + pyg_graphs.append(data) + + return pyg_graphs + + def to_pyg_graphs(self, include_object_ids: bool = False): + return self.to_pytorch_graphs(include_object_ids) + + def to_spektral_graphs( + self, include_object_ids: bool = False + ) -> List["spektral.data.Graph"]: + try: + from spektral.data import Graph + except ImportError: + raise SpektralDependencyError() + if not self.graph_frames: self.to_graph_frames(include_object_ids) @@ -258,7 +318,7 @@ def return_dtypes(self): { "e": pl.List(pl.List(pl.Float64)), "x": pl.List(pl.List(pl.Float64)), - "a": pl.List(pl.List(pl.Float64)), + "a": pl.List(pl.List(pl.Int32)), "e_shape_0": pl.Int64, "e_shape_1": pl.Int64, "x_shape_0": pl.Int64, @@ -267,8 +327,9 @@ def return_dtypes(self): "a_shape_1": pl.Int64, self.graph_id_column: pl.String, self.label_column: pl.Int64, - "object_ids": pl.List(pl.List(pl.String)), - # "frame_id": pl.String + "object_ids": pl.List(pl.String), + "frame_id": pl.Int64, + "ball_owning_team_id": pl.String, } ) @@ -278,6 +339,12 @@ def __convert_object_ids(objects): # convert padded players to None return [x if x != "" else None for x in objects] + # In process_chunk, before the reshape_from_size call: + if chunk["a"][0] is None: + print(f"Row {0}: a is None") + print(f"Full row: {chunk.row(0, named=True)}") + raise ValueError(f"Unexpected None value at row {0}") + return [ { **{ @@ -306,7 +373,7 @@ def __convert_object_ids(objects): **( { "object_ids": __convert_object_ids( - list(chunk["object_ids"][i][0]) + list(chunk["object_ids"][i]) ) } if include_object_ids @@ -320,7 +387,7 @@ def __convert_object_ids(objects): self.graph_frames = [ graph for chunk in graph_df.lazy() - .collect(engine="gpu") + .collect(engine=self.engine) .iter_slices(self.chunk_size) for graph in process_chunk(chunk) ] diff --git a/unravel/utils/objects/graph_dataset.py b/unravel/utils/objects/graph_dataset.py index 59dd5922..90d438bf 100644 --- a/unravel/utils/objects/graph_dataset.py +++ b/unravel/utils/objects/graph_dataset.py @@ -1,44 +1,45 @@ import logging import sys -from typing import List, Tuple, Union, Optional +from typing import List, Tuple, Union, Optional, Literal import numpy as np -import random - import gzip import pickle from pathlib import Path import warnings -import tensorflow as tf - from collections.abc import Sequence -from spektral.data import Dataset, Graph -from spektral.data.utils import get_spec +from unravel.utils.exceptions import NoGraphIdsWarning, SpektralDependencyError -from ..exceptions import NoGraphIdsWarning - -# Function to load data from a .pickle.gz file def load_pickle_gz(file_path): with gzip.open(file_path, "rb") as f: data = pickle.load(f) return data -class GraphDataset(Dataset, Sequence): +class _GraphDatasetMixin: """ - A GraphDataset is required to use all Spektral funcitonality, see 'spektral.data -> Dataset' + Base mixin for graph dataset functionality. + Framework-agnostic implementation that works with both Spektral and PyTorch Geometric. """ def __init__(self, **kwargs): """ Constructor to load parameters. + + Args: + pickle_folder: Path to folder containing .pickle.gz files + pickle_file: Path to single .pickle.gz file + graphs: List of graph objects (Spektral Graph, PyG Data, or dicts) + format: Optional explicit format specification ('spektral' or 'pyg') + sample_rate: Sampling rate (1.0 = use all data) """ self._kwargs = kwargs + self._explicit_format = kwargs.get("format", None) sample_rate = kwargs.get("sample_rate", 1.0) self.sample = 1.0 / sample_rate @@ -68,49 +69,37 @@ def __init__(self, **kwargs): if not isinstance(kwargs["graphs"], list): raise NotImplementedError("""data should be of type list""") - self.graphs = kwargs["graphs"] + self.graphs = self.__convert(kwargs["graphs"]) else: raise NotImplementedError( "Please provide either 'pickle_folder', 'pickle_file' or 'graphs' as parameter to GraphDataset" ) - super().__init__(**kwargs) + # Only call super().__init__ if there's a parent class that needs it + # For PyGGraphDataset, Sequence doesn't take kwargs + # For SpektralGraphDataset, Dataset does take kwargs + try: + super().__init__(**kwargs) + except TypeError: + # If super().__init__() doesn't accept kwargs (like Sequence), call it without args + super().__init__() - def __convert(self, data) -> List[Graph]: + def __convert(self, data): """ - Convert incoming data to correct List[Graph] format + Convert incoming data to correct format. + Must be implemented by subclasses. """ - if isinstance(data[0], Graph): - return [g for i, g in enumerate(data) if i % self.sample == 0] - elif isinstance(data[0], dict): - return [ - Graph( - x=g["x"], - a=g["a"], - e=g["e"], - y=g["y"], - id=g["id"], - frame_id=g.get("frame_id", None), - object_ids=g.get("object_ids", None), - ball_owning_team_id=g.get("ball_owning_team_id", None), - ) - for i, g in enumerate(data) - if i % self.sample == 0 - ] - else: - raise NotImplementedError() + raise NotImplementedError("Subclasses must implement __convert()") - def read(self) -> List[Graph]: + def read(self): """ - Overriding the read function - to return a list of Graph objects + Overriding the read function - to return a list of Graph objects. + Must be implemented by subclasses. """ - graphs = self.__convert(self.graphs) - - logging.info(f"Loading {len(graphs)} graphs into GraphDataset...") - - return graphs + raise NotImplementedError("Subclasses must implement read()") def add(self, other, verbose: bool = False): + """Add more graphs to the dataset""" other = self.__convert(other) if verbose: @@ -123,15 +112,10 @@ def dimensions(self) -> Tuple[int, int, int, int, int]: N = Max number of nodes F = Dimensions of Node Features S = Dimensions of Edge Features - n_out = Dimesion of the target + n_out = Dimension of the target n = Number of samples in dataset """ - N = max(g.n_nodes for g in self) - F = self.n_node_features - S = self.n_edge_features - n_out = self.n_labels - n = len(self) - return (N, F, S, n_out, n) + raise NotImplementedError("Subclasses must implement dimensions()") def split_test_train( self, @@ -165,25 +149,7 @@ def split_test_train_validation( ): """ Split dataset into train, test, and validation sets with optional label balancing. - - split_train (float): amount of total samples that will go into train set - split_test (float): amount of total samples that will go into test set. - split_validation (float): amount of total samples that will go into validation set. Defaults to 0.0. - by_graph_id (bool): when we want to split the samples by graph_id, such that all graphs with the same id end up in the same train/test/validation set - set to True. Defaults to False. When set to True the split ratio's will be approximated, - because we can't be sure to split the graphs exactly according to the ratios. - random_seed (int, optional): Random seed for reproducibility - train_label_ratio (float, optional): If provided, balances the training set to have this ratio of labels (0/1). - Must be between 0 and 1. Defaults to None (keep original distribution). - test_label_ratio (float, optional): If provided, balances the test set to have this ratio of labels (0/1). - Must be between 0 and 1. Defaults to None (keep original distribution). - val_label_ratio (float, optional): If provided, balances the validation set to have this ratio of labels (0/1). - Must be between 0 and 1. Defaults to None (keep original distribution). - - for an explanation on splitting behaviour when by_graph_id = True - see: https://github.com/USSoccerFederation/ussf_ssac_23_soccer_gnn/blob/main/split_sequences.py """ - total = split_train + split_test + split_validation train_pct = split_train / total @@ -210,7 +176,10 @@ def split_test_train_validation( num_validation = 0 unique_graph_ids = set( - [g.get("id") if hasattr(g, "id") else None for g in self] + [ + g.get("id") if hasattr(g, "id") else getattr(g, "graph_id", None) + for g in self + ] ) if unique_graph_ids == {None}: by_graph_id = False @@ -221,7 +190,6 @@ def split_test_train_validation( ) if not by_graph_id: - # if we don't use the graph_ids we simply shuffle all indices and return 2 or 3 randomly shuffled datasets if random_seed: idxs = np.random.RandomState(seed=random_seed).permutation( dataset_length @@ -240,7 +208,6 @@ def split_test_train_validation( test_set = self[test_idxs] validation_set = self[validation_idxs] - # Apply label balancing if requested if train_label_ratio is not None: train_set = self._balance_labels( train_set, train_label_ratio, random_seed @@ -262,7 +229,6 @@ def split_test_train_validation( train_set = self[train_idxs] test_set = self[test_idxs] - # Apply label balancing if requested if train_label_ratio is not None: train_set = self._balance_labels( train_set, train_label_ratio, random_seed @@ -274,9 +240,17 @@ def split_test_train_validation( return train_set, test_set else: - # if we do use the graph_ids we randomly assign all items of a certain graph_id to either - # val, test or train. We start with validation, because it's assumed to be the smallest dataset. - graph_ids = np.asarray([g.get("id")[0] for g in self]) + # Get graph IDs in a framework-agnostic way + graph_ids = np.asarray( + [ + ( + g.get("id") + if hasattr(g, "get") and g.get("id") is not None + else getattr(g, "graph_id", None) + ) + for g in self + ] + ) if random_seed: np.random.seed(random_seed) @@ -312,7 +286,6 @@ def __handle_graph_id(i): test_set = self[test_idxs] validation_set = self[validation_idxs] - # Apply label balancing if requested if train_label_ratio is not None: train_set = self._balance_labels( train_set, train_label_ratio, random_seed @@ -331,7 +304,6 @@ def __handle_graph_id(i): train_set = self[train_idxs] test_set = self[test_idxs] - # Apply label balancing if requested if train_label_ratio is not None: train_set = self._balance_labels( train_set, train_label_ratio, random_seed @@ -344,80 +316,53 @@ def __handle_graph_id(i): return train_set, test_set def _balance_labels(self, dataset, target_ratio, random_seed): - """ - Balance a dataset to achieve a target ratio of labels. - - Args: - dataset: A GraphDataset containing Graph objects - target_ratio: Float between 0 and 1, representing the desired ratio of positive labels - (e.g., 0.5 for a 50/50 split) - - Returns: - A balanced subset of the dataset with the desired label ratio - """ + """Balance a dataset to achieve a target ratio of labels.""" if random_seed: np.random.seed(random_seed) if not 0 <= target_ratio <= 1: raise ValueError("target_ratio must be between 0 and 1") - # Identify indices by label indices_by_label = {0: [], 1: []} for i, g in enumerate(dataset): # Handle different types of label storage if hasattr(g, "y"): - if isinstance(g.y, (np.ndarray, list)): - # Check that y is not longer than 1 item - if len(g.y) != 1: - raise ValueError( - f"Expected y to be a single value, but got array of length {len(g.y)}" - ) - label = 1 if g.y[0] > 0.5 else 0 - else: - label = 1 if g.y > 0.5 else 0 - elif g.get("y", None) is not None: - # If using dictionary access + y_value = g.y + elif hasattr(g, "get") and g.get("y", None) is not None: y_value = g["y"] - if isinstance(y_value, (np.ndarray, list)): - # Check that y is not longer than 1 item - if len(y_value) != 1: - raise ValueError( - f"Expected y to be a single value, but got array of length {len(y_value)}" - ) - label = 1 if y_value[0] > 0.5 else 0 - else: - label = 1 if y_value > 0.5 else 0 else: raise ValueError("Graph has no attribute 'y'...") + if isinstance(y_value, (np.ndarray, list)): + if len(y_value) != 1: + raise ValueError( + f"Expected y to be a single value, but got array of length {len(y_value)}" + ) + label = 1 if y_value[0] > 0.5 else 0 + else: + label = 1 if y_value > 0.5 else 0 + indices_by_label[label].append(i) - # Count samples for each class n_zeros = len(indices_by_label[0]) n_ones = len(indices_by_label[1]) total = n_zeros + n_ones - # Calculate current ratio current_ratio = n_ones / total if total > 0 else 0 - # If already matching target ratio (within 1%), return as is if abs(current_ratio - target_ratio) < 0.01: return dataset - # Calculate how many samples we need for each class if current_ratio > target_ratio: - # Too many positives, keep all negatives target_ones = int(n_zeros * target_ratio / (1 - target_ratio)) target_zeros = n_zeros else: - # Too many negatives, keep all positives target_zeros = int(n_ones * (1 - target_ratio) / target_ratio) target_ones = n_ones indices_to_keep = [] - # Keep samples from class 0 (negative) if n_zeros > target_zeros: sampled_zeros = np.random.choice( indices_by_label[0], target_zeros, replace=False @@ -426,7 +371,6 @@ def _balance_labels(self, dataset, target_ratio, random_seed): else: indices_to_keep.extend(indices_by_label[0]) - # Keep samples from class 1 (positive) if n_ones > target_ones: sampled_ones = np.random.choice( indices_by_label[1], target_ones, replace=False @@ -435,55 +379,341 @@ def _balance_labels(self, dataset, target_ratio, random_seed): else: indices_to_keep.extend(indices_by_label[1]) - # Shuffle indices np.random.shuffle(indices_to_keep) - # Return a subset of the dataset using the balanced indices return dataset[indices_to_keep] - @property - def signature(self): - """ - This property computes the signature of the dataset, which can be - passed to `spektral.data.utils.to_tf_signature(signature)` to compute - the TensorFlow signature. - The signature includes TensorFlow TypeSpec, shape, and dtype for all - characteristic matrices of the graphs in the Dataset. +# ============================================================================= +# SPEKTRAL IMPLEMENTATION +# ============================================================================= + +try: + from spektral.data import Dataset, Graph + from spektral.data.utils import get_spec + import tensorflow as tf + + _HAS_SPEKTRAL = True + + class SpektralGraphDataset(_GraphDatasetMixin, Dataset, Sequence): """ - if len(self.graphs) == 0: - return None - signature = {} - graph = self.graphs[0] # This is always non-empty - - if graph.x is not None: - signature["x"] = dict() - signature["x"]["spec"] = get_spec(graph.x) - signature["x"]["shape"] = (None, self.n_node_features) - signature["x"]["dtype"] = tf.as_dtype(graph.x.dtype) - - if graph.a is not None: - signature["a"] = dict() - signature["a"]["spec"] = get_spec(graph.a) - signature["a"]["shape"] = (None, None) - signature["a"]["dtype"] = tf.as_dtype(graph.a.dtype) - - if graph.e is not None: - signature["e"] = dict() - signature["e"]["spec"] = get_spec(graph.e) - signature["e"]["shape"] = (None, self.n_edge_features) - signature["e"]["dtype"] = tf.as_dtype(graph.e.dtype) - - if graph.y is not None: - signature["y"] = dict() - signature["y"]["spec"] = get_spec(graph.y) - signature["y"]["shape"] = (self.n_labels,) - signature["y"]["dtype"] = tf.as_dtype(np.array(graph.y).dtype) - - if hasattr(graph, "g") and graph.g is not None: - signature["g"] = dict() - signature["g"]["spec"] = get_spec(graph.g) - signature["g"]["shape"] = graph.g.shape - signature["g"]["dtype"] = tf.as_dtype(np.array(graph.g).dtype) - - return signature + Spektral-specific GraphDataset implementation. + """ + + def _SpektralGraphDataset__convert(self, data) -> List: + """Convert incoming data to Spektral Graph format""" + from spektral.data import Graph + + if isinstance(data[0], Graph): + return [g for i, g in enumerate(data) if i % self.sample == 0] + elif isinstance(data[0], dict): + return [ + Graph( + x=g["x"], + a=g["a"], + e=g["e"], + y=g["y"], + id=g["id"], + frame_id=g.get("frame_id", None), + object_ids=g.get("object_ids", None), + ball_owning_team_id=g.get("ball_owning_team_id", None), + ) + for i, g in enumerate(data) + if i % self.sample == 0 + ] + else: + raise ValueError( + f"Cannot convert type {type(data[0])} to Spektral Graph. " + "Expected Spektral Graph or dict." + ) + + _GraphDatasetMixin__convert = _SpektralGraphDataset__convert + + def read(self) -> List: + """Return a list of Spektral Graph objects""" + graphs = self._SpektralGraphDataset__convert(self.graphs) + logging.info(f"Loading {len(graphs)} graphs into SpektralGraphDataset...") + return graphs + + def dimensions(self) -> Tuple[int, int, int, int, int]: + """N, F, S, n_out, n""" + N = max(g.n_nodes for g in self) + F = self.n_node_features + S = self.n_edge_features + n_out = self.n_labels + n = len(self) + return (N, F, S, n_out, n) + + @property + def signature(self): + """Compute TensorFlow signature for the dataset""" + from spektral.data.utils import get_spec + import tensorflow as tf + + if len(self.graphs) == 0: + return None + signature = {} + graph = self.graphs[0] + + if graph.x is not None: + signature["x"] = dict() + signature["x"]["spec"] = get_spec(graph.x) + signature["x"]["shape"] = (None, self.n_node_features) + signature["x"]["dtype"] = tf.as_dtype(graph.x.dtype) + + if graph.a is not None: + signature["a"] = dict() + signature["a"]["spec"] = get_spec(graph.a) + signature["a"]["shape"] = (None, None) + signature["a"]["dtype"] = tf.as_dtype(graph.a.dtype) + + if graph.e is not None: + signature["e"] = dict() + signature["e"]["spec"] = get_spec(graph.e) + signature["e"]["shape"] = (None, self.n_edge_features) + signature["e"]["dtype"] = tf.as_dtype(graph.e.dtype) + + if graph.y is not None: + signature["y"] = dict() + signature["y"]["spec"] = get_spec(graph.y) + signature["y"]["shape"] = (self.n_labels,) + signature["y"]["dtype"] = tf.as_dtype(np.array(graph.y).dtype) + + if hasattr(graph, "g") and graph.g is not None: + signature["g"] = dict() + signature["g"]["spec"] = get_spec(graph.g) + signature["g"]["shape"] = graph.g.shape + signature["g"]["dtype"] = tf.as_dtype(np.array(graph.g).dtype) + + return signature + +except ImportError: + _HAS_SPEKTRAL = False + + # Create a dummy class that raises an informative error + class SpektralGraphDataset: + def __init__(self, *args, **kwargs): + raise SpektralDependencyError() + + +# ============================================================================= +# PYTORCH GEOMETRIC IMPLEMENTATION +# ============================================================================= + +try: + import torch + from torch_geometric.data import Data + + _HAS_TORCH_GEOMETRIC = True +except ImportError: + _HAS_TORCH_GEOMETRIC = False + + +class PyGGraphDataset(_GraphDatasetMixin, Sequence): + """ + PyTorch Geometric GraphDataset implementation. + """ + + def _PyGGraphDataset__convert(self, data) -> List: + """Convert incoming data to PyG Data format""" + if not _HAS_TORCH_GEOMETRIC: + raise ImportError( + "PyTorch Geometric is required for PyGGraphDataset. " + "Install it using: pip install torch torch-geometric" + ) + + from torch_geometric.data import Data + + if isinstance(data[0], Data): + return [g for i, g in enumerate(data) if i % self.sample == 0] + elif isinstance(data[0], dict): + pyg_graphs = [] + for i, d in enumerate(data): + if i % self.sample != 0: + continue + + # Node features + x = torch.tensor(d["x"], dtype=torch.float) + + # Get adjacency matrix and convert to edge_index + a = d["a"].toarray() if hasattr(d["a"], "toarray") else d["a"] + edge_indices = np.nonzero(a) + edge_index = torch.tensor(np.vstack(edge_indices), dtype=torch.long) + + # Edge features (already aligned with edges) + edge_attr = torch.tensor(d["e"], dtype=torch.float) + + # Labels + y = torch.tensor(d["y"], dtype=torch.long) + + # Create Data object + graph_data = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + y=y, + ) + + # Add custom attributes + graph_data.id = d.get("id", None) + graph_data.frame_id = d.get("frame_id", None) + graph_data.ball_owning_team_id = d.get("ball_owning_team_id", None) + graph_data.object_ids = d.get("object_ids", None) + + pyg_graphs.append(graph_data) + + return pyg_graphs + else: + raise ValueError( + f"Cannot convert type {type(data[0])} to PyG Data. " + "Expected PyG Data or dict." + ) + + _GraphDatasetMixin__convert = _PyGGraphDataset__convert + + def read(self) -> List: + """Return a list of PyG Data objects""" + if not _HAS_TORCH_GEOMETRIC: + raise ImportError( + "PyTorch Geometric is required. " + "Install it using: pip install torch torch-geometric" + ) + + graphs = self._PyGGraphDataset__convert(self.graphs) + logging.info(f"Loading {len(graphs)} graphs into PyGGraphDataset...") + return graphs + + def dimensions(self) -> Tuple[int, int, int, int, int]: + """N, F, S, n_out, n""" + N = max(data.num_nodes for data in self) + F = self[0].num_node_features if len(self) > 0 else 0 + S = self[0].num_edge_features if len(self) > 0 else 0 + n_out = self[0].y.shape[0] if len(self) > 0 else 0 + n = len(self) + return (N, F, S, n_out, n) + + def __len__(self): + return len(self.graphs) + + def __getitem__(self, idx): + if isinstance(idx, (list, np.ndarray)): + selected_graphs = [self.graphs[i] for i in idx] + return PyGGraphDataset(graphs=selected_graphs, sample_rate=1.0) + else: + return self.graphs[idx] + + def __repr__(self): + return f"PyGGraphDataset(n_graphs={len(self)})" + + @property + def n_graphs(self): + return len(self) + + +def GraphDataset( + format: Optional[Literal["spektral", "pyg"]] = "spektral", **kwargs +) -> Union[SpektralGraphDataset, PyGGraphDataset]: + """ + Factory function that automatically detects and creates the appropriate dataset. + + Args: + format: Optional format specification ('spektral' or 'pyg'). + Only required when passing dict format graphs or pickle files. + For Spektral Graph or PyG Data objects, format is auto-detected. + **kwargs: Arguments passed to the dataset constructor + + Returns: + SpektralGraphDataset or PyGGraphDataset depending on format + + Examples: + # Auto-detect from Spektral graphs + dataset = GraphDataset(graphs=spektral_graph_list) + + # Auto-detect from PyG graphs + dataset = GraphDataset(graphs=pyg_data_list) + + # Explicit format required for dicts + dataset = GraphDataset(graphs=dict_list, format='pyg') + + # Explicit format required for pickle files + dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='spektral') + """ + import warnings + + if format == "spektral": + warnings.warn( + """ +unravelsports now supports PyTorch Geometric. The default "format" will change from 'spektral' to 'pyg' in a future version. +\nNote: format='spektral' only really works on Python 3.11, due to very specific package requirements. PyTorch works on 3.11+. +""", + FutureWarning, + ) + + def _create_dataset(fmt: str): + """Helper function to create the appropriate dataset""" + if fmt.lower() == "spektral": + if not _HAS_SPEKTRAL: + raise SpektralDependencyError() + return SpektralGraphDataset(**kwargs) + elif fmt.lower() == "pyg": + if not _HAS_TORCH_GEOMETRIC: + raise ImportError( + "PyTorch Geometric is required. " + "Install it using: pip install torch torch-geometric" + ) + return PyGGraphDataset(**kwargs) + else: + raise ValueError(f"format must be 'spektral' or 'pyg', got '{fmt}'") + + # Auto-detect from graphs if provided + if kwargs.get("graphs", None) is not None: + graphs = kwargs["graphs"] + + if not isinstance(graphs, list) or len(graphs) == 0: + raise ValueError("graphs must be a non-empty list") + + first_item = graphs[0] + + # Check if it's a dict - require explicit format + if isinstance(first_item, dict): + if format is None: + raise ValueError( + "When passing dict format graphs, you must explicitly specify format='spektral' or format='pyg'" + ) + return _create_dataset(format) + + # Check if it's a Spektral Graph + if _HAS_SPEKTRAL: + from spektral.data import Graph + + if isinstance(first_item, Graph): + return SpektralGraphDataset(**kwargs) + + # Check if it's a PyG Data object + if _HAS_TORCH_GEOMETRIC: + from torch_geometric.data import Data + + if isinstance(first_item, Data): + return PyGGraphDataset(**kwargs) + + # If we can't detect, raise error + raise ValueError( + f"Cannot auto-detect format for type {type(first_item)}. " + "Please specify format='spektral' or format='pyg' explicitly." + ) + + # For pickle files, require explicit format + elif ( + kwargs.get("pickle_file", None) is not None + or kwargs.get("pickle_folder", None) is not None + ): + if format is None: + raise ValueError( + "When loading from pickle files, you must explicitly specify format='spektral' or format='pyg'" + ) + return _create_dataset(format) + + else: + raise ValueError( + "Must provide either 'graphs', 'pickle_file', or 'pickle_folder'" + )