Skip to content

Commit a6af8d4

Browse files
codacy
1 parent b6ce22f commit a6af8d4

File tree

11 files changed

+241
-117
lines changed

11 files changed

+241
-117
lines changed

pina/domain/base_domain.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
"""Module for the Base Domain class."""
1+
"""Module for the Base class for domains."""
22

33
from copy import deepcopy
4+
from abc import ABCMeta
45
from .domain_interface import DomainInterface
56
from ..utils import check_consistency, check_positive_integer
67

78

8-
class BaseDomain(DomainInterface):
9+
class BaseDomain(DomainInterface, metaclass=ABCMeta):
910
"""
1011
Base class for all geometric domains, implementing common functionality.
1112
@@ -51,9 +52,9 @@ def __init__(self, domain_dict):
5152
check_consistency(v, (int, float))
5253

5354
# Store
54-
self._fixed = dict()
55-
self._range = dict()
56-
invalid = list()
55+
self._fixed = {}
56+
self._range = {}
57+
invalid = []
5758

5859
# Iterate over domain_dict items
5960
for k, v in domain_dict.items():
@@ -70,8 +71,7 @@ def __init__(self, domain_dict):
7071
f"Invalid range for variable '{k}': "
7172
f"low ({low}) >= high ({high})"
7273
)
73-
else:
74-
self._range[k] = (low, high)
74+
self._range[k] = (low, high)
7575

7676
# Save invalid keys
7777
else:
@@ -106,8 +106,8 @@ def update(self, domain):
106106

107107
# Update fixed and ranged variables
108108
updated = deepcopy(self)
109-
updated._fixed.update(domain._fixed)
110-
updated._range.update(domain._range)
109+
updated.fixed.update(domain.fixed)
110+
updated.range.update(domain.range)
111111

112112
return updated
113113

@@ -207,3 +207,23 @@ def domain_dict(self):
207207
:rtype: dict
208208
"""
209209
return {**self._fixed, **self._range}
210+
211+
@property
212+
def range(self):
213+
"""
214+
The range variables of the domain.
215+
216+
:return: The range variables of the domain.
217+
:rtype: dict
218+
"""
219+
return self._range
220+
221+
@property
222+
def fixed(self):
223+
"""
224+
The fixed variables of the domain.
225+
226+
:return: The fixed variables of the domain.
227+
:rtype: dict
228+
"""
229+
return self._fixed

pina/domain/base_operation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
"""Module for the Base Operation class."""
1+
"""Module for the Base class for all set-operations."""
22

33
from copy import deepcopy
4+
from abc import ABCMeta
45
from .operation_interface import OperationInterface
56
from .base_domain import BaseDomain
67
from ..utils import check_consistency, check_positive_integer
78

89

9-
class BaseOperation(OperationInterface):
10+
class BaseOperation(OperationInterface, metaclass=ABCMeta):
1011
"""
1112
Base class for all set operation defined on geometric domains, implementing
1213
common functionality.
@@ -172,6 +173,26 @@ def geometries(self):
172173
"""
173174
return self._geometries
174175

176+
@property
177+
def range(self):
178+
"""
179+
The range variables of each geometry.
180+
181+
:return: The range variables of each geometry.
182+
:rtype: dict
183+
"""
184+
return {f"geometry_{i}": g.range for i, g in enumerate(self.geometries)}
185+
186+
@property
187+
def fixed(self):
188+
"""
189+
The fixed variables of each geometry.
190+
191+
:return: The fixed variables of each geometry.
192+
:rtype: dict
193+
"""
194+
return {f"geometry_{i}": g.fixed for i, g in enumerate(self.geometries)}
195+
175196
@geometries.setter
176197
def geometries(self, values):
177198
"""

pina/domain/cartesian_domain.py

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ def is_inside(self, point, check_border=False):
4141

4242
# Fixed variable checks
4343
fixed_check = all(
44-
bool((point.extract([k]) == v).all())
45-
for k, v in self._fixed.items()
46-
if k in point.labels
44+
(point.extract([k]) == v).all() for k, v in self._fixed.items()
4745
)
4846

4947
# If there are no range variables, return fixed variable check
@@ -53,27 +51,17 @@ def is_inside(self, point, check_border=False):
5351
# Ranged variable checks -- check_border True
5452
if check_border:
5553
range_check = all(
56-
bool(
57-
(
58-
(point.extract([k]) >= self._range[k][0])
59-
& (point.extract([k]) <= self._range[k][1])
60-
).all()
61-
)
62-
for k in self._range
63-
if k in point.labels
54+
(
55+
(point.extract([k]) >= low) & (point.extract([k]) <= high)
56+
).all()
57+
for k, (low, high) in self._range.items()
6458
)
6559

6660
# Ranged variable checks -- check_border False
6761
else:
6862
range_check = all(
69-
bool(
70-
(
71-
(point.extract([k]) > self._range[k][0])
72-
& (point.extract([k]) < self._range[k][1])
73-
).all()
74-
)
75-
for k in self._range
76-
if k in point.labels
63+
((point.extract([k]) > low) & (point.extract([k]) < high)).all()
64+
for k, (low, high) in self._range.items()
7765
)
7866

7967
return fixed_check and range_check
@@ -213,29 +201,17 @@ def partial(self):
213201
:return: The boundary of the domain.
214202
:rtype: Union
215203
"""
216-
faces = list()
204+
faces = []
217205

218206
# Iterate over ranged variables
219207
for var, (low, high) in self._range.items():
220208

221209
# Fix the variable to its low value to get the lower face
222-
low_face = CartesianDomain(
223-
{
224-
**self._fixed,
225-
**{k: v for k, v in self._range.items()},
226-
var: low,
227-
}
228-
)
210+
lower = CartesianDomain({**self._fixed, **self._range, var: low})
229211

230212
# Fix the variable to its high value to get the upper face
231-
high_face = CartesianDomain(
232-
{
233-
**self._fixed,
234-
**{k: v for k, v in self._range.items()},
235-
var: high,
236-
}
237-
)
213+
higher = CartesianDomain({**self._fixed, **self._range, var: high})
238214

239-
faces.extend([low_face, high_face])
215+
faces.extend([lower, higher])
240216

241217
return Union(faces)

pina/domain/domain_interface.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,23 @@ def domain_dict(self):
9393
:return: The dictionary representing the domain.
9494
:rtype: dict
9595
"""
96+
97+
@property
98+
@abstractmethod
99+
def range(self):
100+
"""
101+
The range variables of the domain.
102+
103+
:return: The range variables of the domain.
104+
:rtype: dict
105+
"""
106+
107+
@property
108+
@abstractmethod
109+
def fixed(self):
110+
"""
111+
The fixed variables of the domain.
112+
113+
:return: The fixed variables of the domain.
114+
:rtype: dict
115+
"""

pina/domain/ellipsoid_domain.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for the Ellipsoid Domain."""
22

3-
import torch
43
from copy import deepcopy
4+
import torch
55
from .base_domain import BaseDomain
66
from ..label_tensor import LabelTensor
77
from ..utils import check_consistency
@@ -43,16 +43,13 @@ def __init__(self, ellipsoid_dict, sample_surface=False):
4343
:raises ValueError: If the ellipsoid dictionary contains values that are
4444
neither numbers nor lists/tuples of numbers of length 2.
4545
"""
46-
# Check consistency
47-
check_consistency(sample_surface, bool)
48-
4946
# Initialization
5047
super().__init__(domain_dict=ellipsoid_dict)
51-
self._sample_surface = sample_surface
48+
self.sample_surface = sample_surface
5249
self.sample_modes = ["random"]
53-
self._compute_center_axes()
50+
self.compute_center_axes()
5451

55-
def _compute_center_axes(self):
52+
def compute_center_axes(self):
5653
"""
5754
Compute centers and axes for the ellipsoid.
5855
"""
@@ -93,9 +90,7 @@ def is_inside(self, point, check_border=False):
9390

9491
# Fixed variable checks
9592
fixed_check = all(
96-
bool((point.extract([k]) == v).all())
97-
for k, v in self._fixed.items()
98-
if k in point.labels
93+
(point.extract([k]) == v).all() for k, v in self._fixed.items()
9994
)
10095

10196
# If there are no range variables, return fixed variable check
@@ -116,7 +111,7 @@ def is_inside(self, point, check_border=False):
116111
# Range variable check in the volume
117112
range_check = (eqn <= 0) if check_border else (eqn < 0)
118113

119-
return fixed_check and bool(range_check.item())
114+
return fixed_check and range_check.item()
120115

121116
def update(self, domain):
122117
"""
@@ -132,7 +127,7 @@ def update(self, domain):
132127
:rtype: EllipsoidDomain
133128
"""
134129
updated = super().update(domain)
135-
updated._compute_center_axes()
130+
updated.compute_center_axes()
136131

137132
return updated
138133

@@ -180,7 +175,7 @@ def sample(self, n, mode="random", variables="all"):
180175
return result
181176

182177
# Sample points
183-
pts = self._sample_range(n, mode, range_vars)
178+
pts = self._sample_range(n, range_vars)
184179
labels = range_vars
185180

186181
# Add fixed vars
@@ -198,12 +193,11 @@ def sample(self, n, mode="random", variables="all"):
198193

199194
return pts[sorted(pts.labels)]
200195

201-
def _sample_range(self, n, mode, variables):
196+
def _sample_range(self, n, variables):
202197
"""
203198
Sample points and rescale to fit within the specified bounds.
204199
205200
:param int n: The number of points to sample.
206-
:param str mode: The sampling method. Default is ``random``.
207201
:param list[str] variables: variables whose samples must be rescaled.
208202
:return: The rescaled sample points.
209203
:rtype: torch.Tensor
@@ -242,6 +236,29 @@ def partial(self):
242236
:rtype: EllipsoidDomain
243237
"""
244238
boundary = deepcopy(self)
245-
boundary._sample_surface = True
239+
boundary.sample_surface = True
246240

247241
return boundary
242+
243+
@property
244+
def sample_surface(self):
245+
"""
246+
Whether only the surface of the ellipsoid is considered part of the
247+
domain.
248+
249+
:return: ``True`` if only the surface is considered part of the domain,
250+
``False`` otherwise.
251+
:rtype: bool
252+
"""
253+
return self._sample_surface
254+
255+
@sample_surface.setter
256+
def sample_surface(self, value):
257+
"""
258+
Setter for the sample_surface property.
259+
260+
:param bool value: The new value for the sample_surface property.
261+
:raises ValueError: If ``value`` is not a boolean.
262+
"""
263+
check_consistency(value, bool)
264+
self._sample_surface = value

pina/domain/exclusion.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Module for the Exclusion operation."""
1+
"""Module for the Exclusion set-operation."""
22

33
import random
44
from .base_operation import BaseOperation
@@ -87,23 +87,22 @@ def sample(self, n, mode="random", variables="all"):
8787
# Validate sampling settings
8888
variables = self._validate_sampling(n, mode, variables)
8989

90-
# Save the number of geometries
91-
n_geometries = len(self.geometries)
92-
9390
# Compute number of points per geometry and remainder
94-
num_points, remainder = divmod(n, n_geometries)
91+
num_pts, remainder = divmod(n, len(self.geometries))
9592

9693
# Shuffle indices
97-
shuffled_geometries = random.sample(range(n_geometries), n_geometries)
94+
shuffled_geometries = random.sample(
95+
range(len(self.geometries)), len(self.geometries)
96+
)
9897

9998
# Precompute per-geometry allocations following the shuffled order
100-
alloc = [num_points + (i < remainder) for i in range(n_geometries)]
99+
alloc = [num_pts + (i < remainder) for i in range(len(self.geometries))]
101100
samples = []
102101

103102
# Iterate over geometries in shuffled order
104103
for idx, gi in enumerate(shuffled_geometries):
105104

106-
# Skip if no points to allocate (possible if n_geometries > n)
105+
# If no points to allocate (possible if len(self.geometries) > n)
107106
if alloc[idx] == 0:
108107
continue
109108

pina/domain/intersection.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,23 +77,22 @@ def sample(self, n, mode="random", variables="all"):
7777
# Validate sampling settings
7878
variables = self._validate_sampling(n, mode, variables)
7979

80-
# Save the number of geometries
81-
n_geometries = len(self.geometries)
82-
8380
# Compute number of points per geometry and remainder
84-
num_points, remainder = divmod(n, n_geometries)
81+
num_pts, remainder = divmod(n, len(self.geometries))
8582

8683
# Shuffle indices
87-
shuffled_geometries = random.sample(range(n_geometries), n_geometries)
84+
shuffled_geometries = random.sample(
85+
range(len(self.geometries)), len(self.geometries)
86+
)
8887

8988
# Precompute per-geometry allocations following the shuffled order
90-
alloc = [num_points + (i < remainder) for i in range(n_geometries)]
89+
alloc = [num_pts + (i < remainder) for i in range(len(self.geometries))]
9190
samples = []
9291

9392
# Iterate over geometries in shuffled order
9493
for idx, gi in enumerate(shuffled_geometries):
9594

96-
# Skip if no points to allocate (possible if n_geometries > n)
95+
# If no points to allocate (possible if len(self.geometries) > n)
9796
if alloc[idx] == 0:
9897
continue
9998

0 commit comments

Comments
 (0)