diff --git a/.gitignore b/.gitignore index 577c8727..2abc593f 100644 --- a/.gitignore +++ b/.gitignore @@ -197,4 +197,5 @@ examples/models/* *.json diffs/ -lightning_logs/ \ No newline at end of file +lightning_logs/ +tests/files/plot/* \ No newline at end of file diff --git a/tests/test_american_football.py b/tests/test_american_football.py index 99dc3d53..e6dc4101 100644 --- a/tests/test_american_football.py +++ b/tests/test_american_football.py @@ -416,7 +416,7 @@ def test_to_pyg_graph( pyg_graphs = gnnc.to_pytorch_graphs() - dataset = GraphDataset(graphs=pyg_graphs) + dataset = GraphDataset(graphs=pyg_graphs, format="pyg") N, F, S, n_out, n = dataset.dimensions() assert N == 23 assert F == len(node_feature_assert_values.keys()) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 48262632..7d91bc2d 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -7,8 +7,8 @@ from unravel.utils.objects.graph_dataset import SpektralGraphDataset, PyGGraphDataset -class TestGraphDatasetAutoDetection: - """Test auto-detection of graph types""" +class TestGraphDatasetFormats: + """Test graph dataset format specification""" @pytest.fixture def spektral_graphs(self): @@ -65,18 +65,18 @@ def dict_graphs(self): return graphs @pytest.mark.spektral - def test_auto_detect_spektral_graphs(self, spektral_graphs): - """Test that Spektral graphs are auto-detected""" + def test_spektral_graphs(self, spektral_graphs): + """Test creating SpektralGraphDataset with format='spektral'""" - dataset = GraphDataset(graphs=spektral_graphs) + dataset = GraphDataset(graphs=spektral_graphs, format="spektral") assert isinstance(dataset, SpektralGraphDataset) assert len(dataset) == 10 - def test_auto_detect_pyg_graphs(self, pyg_graphs): - """Test that PyG Data objects are auto-detected""" + def test_pyg_graphs(self, pyg_graphs): + """Test creating PyGGraphDataset with format='pyg'""" - dataset = GraphDataset(graphs=pyg_graphs) + dataset = GraphDataset(graphs=pyg_graphs, format="pyg") assert isinstance(dataset, PyGGraphDataset) assert len(dataset) == 10 @@ -234,7 +234,7 @@ def pyg_dataset(self): ) graphs[-1].id = f"graph_{i}" - return GraphDataset(graphs=graphs) + return GraphDataset(graphs=graphs, format="pyg") @pytest.fixture def spektral_dataset(self): @@ -324,6 +324,7 @@ def test_no_input_raises_error(self): with pytest.raises(ValueError): GraphDataset() + @pytest.mark.spektral def test_unknown_graph_type_raises_error(self): """Test that unknown graph type raises an error""" @@ -371,5 +372,5 @@ def test_pyg_repr(self): ) ] - dataset = GraphDataset(graphs=graphs) + dataset = GraphDataset(graphs=graphs, format="pyg") assert repr(dataset) == "PyGGraphDataset(n_graphs=1)" diff --git a/tests/test_soccer.py b/tests/test_soccer.py index e2f849d0..497e05a0 100644 --- a/tests/test_soccer.py +++ b/tests/test_soccer.py @@ -1006,7 +1006,7 @@ def test_pyg_graph(self, soccer_polars_converter: SoccerGraphConverter): assert data[0].frame_id == 1524 assert data[-1].frame_id == 2097 - dataset = GraphDataset(graphs=pyg_graphs) + dataset = GraphDataset(graphs=pyg_graphs, format="pyg") N, F, S, n_out, n = dataset.dimensions() assert N == 20 assert F == 15 diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index 47aa2ac6..36372100 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -1,7 +1,7 @@ import logging import sys -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Union, Dict, Literal, Any, Optional, Callable, TYPE_CHECKING diff --git a/unravel/utils/__init__.py b/unravel/utils/__init__.py index d3ea75e8..db627c33 100644 --- a/unravel/utils/__init__.py +++ b/unravel/utils/__init__.py @@ -2,5 +2,4 @@ from .objects import * from .exceptions import * from .features import * -from .display import * from .helpers import * diff --git a/unravel/utils/objects/graph_dataset.py b/unravel/utils/objects/graph_dataset.py index 90d438bf..9d88f4db 100644 --- a/unravel/utils/objects/graph_dataset.py +++ b/unravel/utils/objects/graph_dataset.py @@ -614,29 +614,24 @@ def GraphDataset( format: Optional[Literal["spektral", "pyg"]] = "spektral", **kwargs ) -> Union[SpektralGraphDataset, PyGGraphDataset]: """ - Factory function that automatically detects and creates the appropriate dataset. + Factory function that creates the appropriate dataset based on format. Args: - format: Optional format specification ('spektral' or 'pyg'). - Only required when passing dict format graphs or pickle files. - For Spektral Graph or PyG Data objects, format is auto-detected. + format: Format specification ('spektral' or 'pyg'). Defaults to 'spektral'. **kwargs: Arguments passed to the dataset constructor Returns: SpektralGraphDataset or PyGGraphDataset depending on format Examples: - # Auto-detect from Spektral graphs - dataset = GraphDataset(graphs=spektral_graph_list) + # Spektral format (default) + dataset = GraphDataset(graphs=spektral_graph_list, format='spektral') - # Auto-detect from PyG graphs - dataset = GraphDataset(graphs=pyg_data_list) + # PyG format + dataset = GraphDataset(graphs=pyg_data_list, format='pyg') - # Explicit format required for dicts - dataset = GraphDataset(graphs=dict_list, format='pyg') - - # Explicit format required for pickle files - dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='spektral') + # From pickle files + dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='pyg') """ import warnings @@ -665,55 +660,18 @@ def _create_dataset(fmt: str): else: raise ValueError(f"format must be 'spektral' or 'pyg', got '{fmt}'") - # Auto-detect from graphs if provided - if kwargs.get("graphs", None) is not None: - graphs = kwargs["graphs"] - - if not isinstance(graphs, list) or len(graphs) == 0: - raise ValueError("graphs must be a non-empty list") - - first_item = graphs[0] - - # Check if it's a dict - require explicit format - if isinstance(first_item, dict): - if format is None: - raise ValueError( - "When passing dict format graphs, you must explicitly specify format='spektral' or format='pyg'" - ) - return _create_dataset(format) - - # Check if it's a Spektral Graph - if _HAS_SPEKTRAL: - from spektral.data import Graph - - if isinstance(first_item, Graph): - return SpektralGraphDataset(**kwargs) - - # Check if it's a PyG Data object - if _HAS_TORCH_GEOMETRIC: - from torch_geometric.data import Data - - if isinstance(first_item, Data): - return PyGGraphDataset(**kwargs) - - # If we can't detect, raise error - raise ValueError( - f"Cannot auto-detect format for type {type(first_item)}. " - "Please specify format='spektral' or format='pyg' explicitly." - ) - - # For pickle files, require explicit format - elif ( - kwargs.get("pickle_file", None) is not None - or kwargs.get("pickle_folder", None) is not None + if ( + kwargs.get("graphs") is None + and kwargs.get("pickle_file") is None + and kwargs.get("pickle_folder") is None ): - if format is None: - raise ValueError( - "When loading from pickle files, you must explicitly specify format='spektral' or format='pyg'" - ) - return _create_dataset(format) - - else: raise ValueError( "Must provide either 'graphs', 'pickle_file', or 'pickle_folder'" ) + + if kwargs.get("graphs") is not None: + graphs = kwargs["graphs"] + if not isinstance(graphs, list) or len(graphs) == 0: + raise ValueError("graphs must be a non-empty list") + + return _create_dataset(format)