Skip to content

Commit 4c30341

Browse files
committed
Conditions refactoring (#758)
1 parent c13901e commit 4c30341

19 files changed

+1593
-618
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Module for managing batches of data with device transfer capabilities.
3+
"""
4+
5+
6+
class _BatchManager(dict):
7+
"""
8+
A dictionary-based batch manager that supports dot-notation
9+
and moving tensors to devices.
10+
"""
11+
12+
def to(self, device):
13+
"""
14+
Move all tensors in the batch to the specified device.
15+
16+
:param device: The target device.
17+
:type device: torch.device | str
18+
:return: The updated batch manager.
19+
:rtype: _BatchManager
20+
"""
21+
for key, value in self.items():
22+
if hasattr(value, "to"):
23+
moved_value = value.to(device)
24+
self[key] = moved_value # Updates both dict and attribute
25+
return self
26+
27+
def __getattribute__(self, name):
28+
"""
29+
Alias attribute access to dictionary keys.
30+
31+
:param str name: The name of the attribute to retrieve.
32+
:return: The value associated with the attribute name.
33+
:rtype: Any
34+
"""
35+
try:
36+
return super().__getattribute__(name)
37+
except AttributeError:
38+
try:
39+
return self[name]
40+
except KeyError:
41+
raise AttributeError(
42+
f"'BatchManager' object has no attribute '{name}'"
43+
)

pina/_src/condition/condition.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ class Condition:
8888
"""
8989

9090
# Combine all possible keyword arguments from the different Condition types
91-
__slots__ = list(
91+
available_kwargs = list(
9292
set(
93-
InputTargetCondition.__slots__
94-
+ InputEquationCondition.__slots__
95-
+ DomainEquationCondition.__slots__
96-
+ DataCondition.__slots__
93+
InputTargetCondition.__fields__
94+
+ InputEquationCondition.__fields__
95+
+ DomainEquationCondition.__fields__
96+
+ DataCondition.__fields__
9797
)
9898
)
9999

@@ -114,28 +114,28 @@ def __new__(cls, *args, **kwargs):
114114
if len(args) != 0:
115115
raise ValueError(
116116
"Condition takes only the following keyword "
117-
f"arguments: {Condition.__slots__}."
117+
f"arguments: {Condition.available_kwargs}."
118118
)
119119

120120
# Class specialization based on keyword arguments
121121
sorted_keys = sorted(kwargs.keys())
122122

123123
# Input - Target Condition
124-
if sorted_keys == sorted(InputTargetCondition.__slots__):
124+
if sorted_keys == sorted(InputTargetCondition.__fields__):
125125
return InputTargetCondition(**kwargs)
126126

127127
# Input - Equation Condition
128-
if sorted_keys == sorted(InputEquationCondition.__slots__):
128+
if sorted_keys == sorted(InputEquationCondition.__fields__):
129129
return InputEquationCondition(**kwargs)
130130

131131
# Domain - Equation Condition
132-
if sorted_keys == sorted(DomainEquationCondition.__slots__):
132+
if sorted_keys == sorted(DomainEquationCondition.__fields__):
133133
return DomainEquationCondition(**kwargs)
134134

135135
# Data Condition
136136
if (
137-
sorted_keys == sorted(DataCondition.__slots__)
138-
or sorted_keys[0] == DataCondition.__slots__[0]
137+
sorted_keys == sorted(DataCondition.__fields__)
138+
or sorted_keys[0] == DataCondition.__fields__[0]
139139
):
140140
return DataCondition(**kwargs)
141141

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""
2+
Base class for conditions.
3+
"""
4+
5+
from functools import partial
6+
import torch
7+
from torch_geometric.data import Batch
8+
from torch.utils.data import DataLoader
9+
from pina._src.condition.condition_interface import ConditionInterface
10+
from pina._src.core.graph import LabelBatch
11+
from pina._src.core.label_tensor import LabelTensor
12+
13+
14+
class ConditionBase(ConditionInterface):
    """
    Base abstract class for all conditions in PINA.
    This class provides common functionality for handling data storage,
    batching, and interaction with the associated problem.
    """

    # Mapping from a data-kind name to the callable used to collate a list of
    # items of that kind into a single batch.
    collate_fn_dict = {
        "tensor": torch.stack,
        "label_tensor": LabelTensor.stack,
        "graph": LabelBatch.from_data_list,
        "data": Batch.from_data_list,
    }

    def __init__(self, **kwargs):
        """
        Initialization of the :class:`ConditionBase` class.

        :param kwargs: Keyword arguments representing the data to be stored.
            They are forwarded unchanged to ``store_data``, which is not
            defined in this class (presumably provided by concrete
            subclasses -- confirm).
        """
        super().__init__()
        # Give the ``problem`` property a defined value until the setter is
        # used; otherwise reading it early raises AttributeError.
        self._problem = None
        self.data = self.store_data(**kwargs)

    @property
    def problem(self):
        """
        Return the problem associated with this condition.

        :return: Problem associated with this condition, or ``None`` if no
            problem has been set yet.
        :rtype: ~pina.problem.abstract_problem.AbstractProblem
        """
        return self._problem

    @problem.setter
    def problem(self, value):
        """
        Set the problem associated with this condition.

        :param pina.problem.abstract_problem.AbstractProblem value: The problem
            to associate with this condition
        """
        self._problem = value

    def __len__(self):
        """
        Return the number of data points in the condition.

        :return: Number of data points, i.e. ``len(self.data)``.
        :rtype: int
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return the data point(s) at the specified index.

        The lookup is delegated to ``self.data``.

        :param idx: Index(es) of the data point(s) to retrieve.
        :type idx: int | list[int]
        :return: Data point(s) at the specified index.
        """
        return self.data[idx]

    @classmethod
    def automatic_batching_collate_fn(cls, batch):
        """
        Collate function for automatic batching to be used in DataLoader.

        The batch is built by calling ``create_batch`` on the class of the
        first item, so all items are assumed to share that class.

        :param batch: A list of items from the dataset.
        :type batch: list
        :return: A collated batch, or an empty dictionary if ``batch`` is
            empty.
        :rtype: dict
        """
        if not batch:
            return {}
        instance_class = batch[0].__class__
        return instance_class.create_batch(batch)

    @staticmethod
    def collate_fn(batch, condition):
        """
        Collate function for custom batching to be used in DataLoader.

        The batch is used to index ``condition.data`` directly, and the
        selection is converted with ``to_batch()``.

        :param batch: A list of items from the dataset, used as an index into
            ``condition.data`` (presumably sample indices -- confirm against
            the dataset in use).
        :type batch: list
        :param condition: The condition instance.
        :type condition: ConditionBase
        :return: A collated batch.
        :rtype: dict
        """
        return condition.data[batch].to_batch()

    def create_dataloader(
        self, dataset, batch_size, shuffle, automatic_batching
    ):
        """
        Create a DataLoader for the condition.

        :param dataset: The dataset to wrap in the DataLoader.
        :param int batch_size: The batch size for the DataLoader.
        :param bool shuffle: Whether to shuffle the data.
        :param bool automatic_batching: If ``True``, batches are collated with
            :meth:`automatic_batching_collate_fn`; otherwise
            :meth:`collate_fn` bound to this condition is used.
        :return: The DataLoader for the condition.
        :rtype: torch.utils.data.DataLoader
        """
        # TODO: add a dedicated path for ``batch_size == len(dataset)``
        # (planned upstream; the previous placeholder branch was a no-op).
        if automatic_batching:
            collate = self.automatic_batching_collate_fn
        else:
            # Bind this condition so the DataLoader can call collate(batch).
            collate = partial(self.collate_fn, condition=self)
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate,
        )
Lines changed: 15 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
"""Module for the Condition interface."""
22

3-
from abc import ABCMeta
4-
from torch_geometric.data import Data
5-
from pina._src.core.label_tensor import LabelTensor
6-
from pina._src.core.graph import Graph
3+
from abc import ABCMeta, abstractmethod
74

85

96
class ConditionInterface(metaclass=ABCMeta):
@@ -15,112 +12,46 @@ class ConditionInterface(metaclass=ABCMeta):
1512
description of all available conditions and how to instantiate them.
1613
"""
1714

18-
def __init__(self):
15+
@abstractmethod
16+
def __init__(self, **kwargs):
1917
"""
2018
Initialization of the :class:`ConditionInterface` class.
2119
"""
22-
self._problem = None
2320

2421
@property
22+
@abstractmethod
2523
def problem(self):
2624
"""
2725
Return the problem associated with this condition.
2826
2927
:return: Problem associated with this condition.
3028
:rtype: ~pina.problem.abstract_problem.AbstractProblem
3129
"""
32-
return self._problem
3330

3431
@problem.setter
32+
@abstractmethod
3533
def problem(self, value):
3634
"""
3735
Set the problem associated with this condition.
3836
3937
:param pina.problem.abstract_problem.AbstractProblem value: The problem
4038
to associate with this condition
4139
"""
42-
self._problem = value
4340

44-
@staticmethod
45-
def _check_graph_list_consistency(data_list):
41+
@abstractmethod
42+
def __len__(self):
4643
"""
47-
Check the consistency of the list of Data | Graph objects.
48-
The following checks are performed:
44+
Return the number of data points in the condition.
4945
50-
- All elements in the list must be of the same type (either
51-
:class:`~torch_geometric.data.Data` or :class:`~pina.graph.Graph`).
52-
53-
- All elements in the list must have the same keys.
54-
55-
- The data type of each tensor must be consistent across all elements.
56-
57-
- If a tensor is a :class:`~pina.label_tensor.LabelTensor`, its labels
58-
must also be consistent across all elements.
59-
60-
:param data_list: The list of Data | Graph objects to check.
61-
:type data_list: list[Data] | list[Graph] | tuple[Data] | tuple[Graph]
62-
:raises ValueError: If the input types are invalid.
63-
:raises ValueError: If all elements in the list do not have the same
64-
keys.
65-
:raises ValueError: If the type of each tensor is not consistent across
66-
all elements in the list.
67-
:raises ValueError: If the labels of the LabelTensors are not consistent
68-
across all elements in the list.
46+
:return: Number of data points.
47+
:rtype: int
6948
"""
70-
# If the data is a Graph or Data object, perform no checks
71-
if isinstance(data_list, (Graph, Data)):
72-
return
73-
74-
# Check all elements in the list are of the same type
75-
if not all(isinstance(i, (Graph, Data)) for i in data_list):
76-
raise ValueError(
77-
"Invalid input. Please, provide either Data or Graph objects."
78-
)
79-
80-
# Store the keys, data types and labels of the first element
81-
data = data_list[0]
82-
keys = sorted(list(data.keys()))
83-
data_types = {name: tensor.__class__ for name, tensor in data.items()}
84-
labels = {
85-
name: tensor.labels
86-
for name, tensor in data.items()
87-
if isinstance(tensor, LabelTensor)
88-
}
89-
90-
# Iterate over the list of Data | Graph objects
91-
for data in data_list[1:]:
92-
93-
# Check that all elements in the list have the same keys
94-
if sorted(list(data.keys())) != keys:
95-
raise ValueError(
96-
"All elements in the list must have the same keys."
97-
)
98-
99-
# Iterate over the tensors in the current element
100-
for name, tensor in data.items():
101-
# Check that the type of each tensor is consistent
102-
if tensor.__class__ is not data_types[name]:
103-
raise ValueError(
104-
f"Data {name} must be a {data_types[name]}, got "
105-
f"{tensor.__class__}"
106-
)
107-
108-
# Check that the labels of each LabelTensor are consistent
109-
if isinstance(tensor, LabelTensor):
110-
if tensor.labels != labels[name]:
111-
raise ValueError(
112-
"LabelTensor must have the same labels"
113-
)
11449

115-
def __getattribute__(self, name):
50+
@abstractmethod
51+
def __getitem__(self, idx):
11652
"""
117-
Get an attribute from the object.
53+
Return the data point(s) at the specified index.
11854
119-
:param str name: The name of the attribute to get.
120-
:return: The requested attribute.
121-
:rtype: Any
55+
:param int idx: Index of the data point(s) to retrieve.
56+
:return: Data point(s) at the specified index.
12257
"""
123-
to_return = super().__getattribute__(name)
124-
if isinstance(to_return, (Graph, Data)):
125-
to_return = [to_return]
126-
return to_return

0 commit comments

Comments
 (0)