Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,5 @@ examples/models/*
*.json
diffs/

lightning_logs/
lightning_logs/
tests/files/plot/*
2 changes: 1 addition & 1 deletion tests/test_american_football.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def test_to_pyg_graph(

pyg_graphs = gnnc.to_pytorch_graphs()

dataset = GraphDataset(graphs=pyg_graphs)
dataset = GraphDataset(graphs=pyg_graphs, format="pyg")
N, F, S, n_out, n = dataset.dimensions()
assert N == 23
assert F == len(node_feature_assert_values.keys())
Expand Down
21 changes: 11 additions & 10 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from unravel.utils.objects.graph_dataset import SpektralGraphDataset, PyGGraphDataset


class TestGraphDatasetAutoDetection:
"""Test auto-detection of graph types"""
class TestGraphDatasetFormats:
"""Test graph dataset format specification"""

@pytest.fixture
def spektral_graphs(self):
Expand Down Expand Up @@ -65,18 +65,18 @@ def dict_graphs(self):
return graphs

@pytest.mark.spektral
def test_auto_detect_spektral_graphs(self, spektral_graphs):
"""Test that Spektral graphs are auto-detected"""
def test_spektral_graphs(self, spektral_graphs):
"""Test creating SpektralGraphDataset with format='spektral'"""

dataset = GraphDataset(graphs=spektral_graphs)
dataset = GraphDataset(graphs=spektral_graphs, format="spektral")

assert isinstance(dataset, SpektralGraphDataset)
assert len(dataset) == 10

def test_auto_detect_pyg_graphs(self, pyg_graphs):
"""Test that PyG Data objects are auto-detected"""
def test_pyg_graphs(self, pyg_graphs):
"""Test creating PyGGraphDataset with format='pyg'"""

dataset = GraphDataset(graphs=pyg_graphs)
dataset = GraphDataset(graphs=pyg_graphs, format="pyg")

assert isinstance(dataset, PyGGraphDataset)
assert len(dataset) == 10
Expand Down Expand Up @@ -234,7 +234,7 @@ def pyg_dataset(self):
)
graphs[-1].id = f"graph_{i}"

return GraphDataset(graphs=graphs)
return GraphDataset(graphs=graphs, format="pyg")

@pytest.fixture
def spektral_dataset(self):
Expand Down Expand Up @@ -324,6 +324,7 @@ def test_no_input_raises_error(self):
with pytest.raises(ValueError):
GraphDataset()

@pytest.mark.spektral
def test_unknown_graph_type_raises_error(self):
"""Test that unknown graph type raises an error"""

Expand Down Expand Up @@ -371,5 +372,5 @@ def test_pyg_repr(self):
)
]

dataset = GraphDataset(graphs=graphs)
dataset = GraphDataset(graphs=graphs, format="pyg")
assert repr(dataset) == "PyGGraphDataset(n_graphs=1)"
2 changes: 1 addition & 1 deletion tests/test_soccer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def test_pyg_graph(self, soccer_polars_converter: SoccerGraphConverter):
assert data[0].frame_id == 1524
assert data[-1].frame_id == 2097

dataset = GraphDataset(graphs=pyg_graphs)
dataset = GraphDataset(graphs=pyg_graphs, format="pyg")
N, F, S, n_out, n = dataset.dimensions()
assert N == 20
assert F == 15
Expand Down
2 changes: 1 addition & 1 deletion unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys

from dataclasses import dataclass
from dataclasses import dataclass, field

from typing import List, Union, Dict, Literal, Any, Optional, Callable, TYPE_CHECKING

Expand Down
1 change: 0 additions & 1 deletion unravel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
from .objects import *
from .exceptions import *
from .features import *
from .display import *
from .helpers import *
80 changes: 19 additions & 61 deletions unravel/utils/objects/graph_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,29 +614,24 @@ def GraphDataset(
format: Optional[Literal["spektral", "pyg"]] = "spektral", **kwargs
) -> Union[SpektralGraphDataset, PyGGraphDataset]:
"""
Factory function that automatically detects and creates the appropriate dataset.
Factory function that creates the appropriate dataset based on format.

Args:
format: Optional format specification ('spektral' or 'pyg').
Only required when passing dict format graphs or pickle files.
For Spektral Graph or PyG Data objects, format is auto-detected.
format: Format specification ('spektral' or 'pyg'). Defaults to 'spektral'.
**kwargs: Arguments passed to the dataset constructor

Returns:
SpektralGraphDataset or PyGGraphDataset depending on format

Examples:
# Auto-detect from Spektral graphs
dataset = GraphDataset(graphs=spektral_graph_list)
# Spektral format (default)
dataset = GraphDataset(graphs=spektral_graph_list, format='spektral')

# Auto-detect from PyG graphs
dataset = GraphDataset(graphs=pyg_data_list)
# PyG format
dataset = GraphDataset(graphs=pyg_data_list, format='pyg')

# Explicit format required for dicts
dataset = GraphDataset(graphs=dict_list, format='pyg')

# Explicit format required for pickle files
dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='spektral')
# From pickle files
dataset = GraphDataset(pickle_file='graphs.pickle.gz', format='pyg')
"""
import warnings

Expand Down Expand Up @@ -665,55 +660,18 @@ def _create_dataset(fmt: str):
else:
raise ValueError(f"format must be 'spektral' or 'pyg', got '{fmt}'")

# Auto-detect from graphs if provided
if kwargs.get("graphs", None) is not None:
graphs = kwargs["graphs"]

if not isinstance(graphs, list) or len(graphs) == 0:
raise ValueError("graphs must be a non-empty list")

first_item = graphs[0]

# Check if it's a dict - require explicit format
if isinstance(first_item, dict):
if format is None:
raise ValueError(
"When passing dict format graphs, you must explicitly specify format='spektral' or format='pyg'"
)
return _create_dataset(format)

# Check if it's a Spektral Graph
if _HAS_SPEKTRAL:
from spektral.data import Graph

if isinstance(first_item, Graph):
return SpektralGraphDataset(**kwargs)

# Check if it's a PyG Data object
if _HAS_TORCH_GEOMETRIC:
from torch_geometric.data import Data

if isinstance(first_item, Data):
return PyGGraphDataset(**kwargs)

# If we can't detect, raise error
raise ValueError(
f"Cannot auto-detect format for type {type(first_item)}. "
"Please specify format='spektral' or format='pyg' explicitly."
)

# For pickle files, require explicit format
elif (
kwargs.get("pickle_file", None) is not None
or kwargs.get("pickle_folder", None) is not None
if (
kwargs.get("graphs") is None
and kwargs.get("pickle_file") is None
and kwargs.get("pickle_folder") is None
):
if format is None:
raise ValueError(
"When loading from pickle files, you must explicitly specify format='spektral' or format='pyg'"
)
return _create_dataset(format)

else:
raise ValueError(
"Must provide either 'graphs', 'pickle_file', or 'pickle_folder'"
)

if kwargs.get("graphs") is not None:
graphs = kwargs["graphs"]
if not isinstance(graphs, list) or len(graphs) == 0:
raise ValueError("graphs must be a non-empty list")

return _create_dataset(format)