Skip to content

Commit f7be49b

Browse files
authored
Merge pull request #219 from firedrakeproject/pbrubeck/merge-release-to-main
2 parents b27dc5c + c860b62 commit f7be49b

File tree

3 files changed

+128
-68
lines changed

3 files changed

+128
-68
lines changed

finat/ufl/finiteelementbase.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@
2020
from ufl.utils.sequences import product
2121

2222

23+
# Dict of supported pullback names and their ufl representation
24+
supported_pullbacks = {
25+
"identity": pullback.identity_pullback,
26+
"L2 Piola": pullback.l2_piola,
27+
"covariant Piola": pullback.covariant_piola,
28+
"contravariant Piola": pullback.contravariant_piola,
29+
"double covariant Piola": pullback.double_covariant_piola,
30+
"double contravariant Piola": pullback.double_contravariant_piola,
31+
"covariant contravariant Piola": pullback.covariant_contravariant_piola,
32+
"custom": pullback.custom_pullback,
33+
"physical": pullback.physical_pullback,
34+
}
35+
36+
2337
class FiniteElementBase(AbstractFiniteElement):
2438
"""Base class for all finite elements."""
2539
__slots__ = ("_family", "_cell", "_degree", "_quad_scheme",
@@ -121,6 +135,14 @@ def is_cellwise_constant(self, component=None):
121135
"""Return whether the basis functions of this element is spatially constant over each cell."""
122136
return self._is_globally_constant() or self.degree() == 0
123137

138+
def value_shape(self, domain=None):
139+
"""Return the shape of the value space on a physical domain."""
140+
return self.pullback.physical_value_shape(self, domain)
141+
142+
def value_size(self, domain=None):
143+
"""Return the integer product of the value shape on a physical domain."""
144+
return product(self.value_shape(domain))
145+
124146
@property
125147
def reference_value_shape(self):
126148
"""Return the shape of the value space on the reference cell."""
@@ -131,7 +153,7 @@ def reference_value_size(self):
131153
"""Return the integer product of the reference value shape."""
132154
return product(self.reference_value_shape)
133155

134-
def symmetry(self): # FIXME: different approach
156+
def symmetry(self, domain=None):
135157
r"""Return the symmetry dict.
136158
137159
This is a mapping :math:`c_0 \\to c_1`
@@ -141,37 +163,37 @@ def symmetry(self): # FIXME: different approach
141163
"""
142164
return {}
143165

144-
def _check_component(self, domain, i):
166+
def _check_component(self, i, domain=None):
145167
"""Check that component index i is valid."""
146-
sh = self.value_shape(domain.geometric_dimension())
168+
sh = self.value_shape(domain)
147169
r = len(sh)
148-
if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))):
170+
if not (len(i) == r and all(int(j) < k for (j, k) in zip(i, sh))):
149171
raise ValueError(
150172
f"Illegal component index {i} (value rank {len(i)}) "
151173
f"for element (value rank {r}).")
152174

153-
def extract_subelement_component(self, domain, i):
175+
def extract_subelement_component(self, i, domain=None):
154176
"""Extract direct subelement index and subelement relative component index for a given component index."""
155177
if isinstance(i, int):
156178
i = (i,)
157-
self._check_component(domain, i)
179+
self._check_component(i, domain)
158180
return (None, i)
159181

160-
def extract_component(self, domain, i):
182+
def extract_component(self, i, domain=None):
161183
"""Recursively extract component index relative to a (simple) element.
162184
163185
and that element for given value component index.
164186
"""
165187
if isinstance(i, int):
166188
i = (i,)
167-
self._check_component(domain, i)
189+
self._check_component(i, domain)
168190
return (i, self)
169191

170192
def _check_reference_component(self, i):
171193
"""Check that reference component index i is valid."""
172194
sh = self.reference_value_shape
173195
r = len(sh)
174-
if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))):
196+
if not (len(i) == r and all(int(j) < k for (j, k) in zip(i, sh))):
175197
raise ValueError(
176198
f"Illegal component index {i} (value rank {len(i)}) "
177199
f"for element (value rank {r}).")
@@ -246,23 +268,7 @@ def embedded_subdegree(self):
246268
@property
247269
def pullback(self):
248270
"""Get the pull back."""
249-
if self.mapping() == "identity":
250-
return pullback.identity_pullback
251-
elif self.mapping() == "L2 Piola":
252-
return pullback.l2_piola
253-
elif self.mapping() == "covariant Piola":
254-
return pullback.covariant_piola
255-
elif self.mapping() == "contravariant Piola":
256-
return pullback.contravariant_piola
257-
elif self.mapping() == "double covariant Piola":
258-
return pullback.double_covariant_piola
259-
elif self.mapping() == "double contravariant Piola":
260-
return pullback.double_contravariant_piola
261-
elif self.mapping() == "covariant contravariant Piola":
262-
return pullback.covariant_contravariant_piola
263-
elif self.mapping() == "custom":
264-
return pullback.custom_pullback
265-
elif self.mapping() == "physical":
266-
return pullback.physical_pullback
267-
268-
raise ValueError(f"Unsupported mapping: {self.mapping()}")
271+
try:
272+
return supported_pullbacks[self.mapping()]
273+
except KeyError:
274+
raise ValueError(f"Unsupported mapping: {self.mapping()}")

finat/ufl/mixedelement.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self, *elements, **kwargs):
4949
if not all(e.quadrature_scheme() == quad_scheme for e in elements):
5050
raise ValueError("Quadrature scheme mismatch for sub elements of mixed element.")
5151

52-
# Compute value sizes in global and reference configurations
52+
# Compute value sizes in reference configuration
5353
reference_value_size_sum = sum(product(s.reference_value_shape) for s in self._sub_elements)
5454

5555
# Default reference value shape: Treated simply as all
@@ -71,7 +71,7 @@ def _make_cell(self):
7171

7272
def __repr__(self):
7373
"""Doc."""
74-
return "MixedElement(" + ", ".join(repr(e) for e in self._sub_elements) + ")"
74+
return "MixedElement(" + ", ".join(map(repr, self._sub_elements)) + ")"
7575

7676
def _is_linear(self):
7777
"""Doc."""
@@ -83,7 +83,7 @@ def reconstruct_from_elements(self, *elements):
8383
return self
8484
return MixedElement(*elements)
8585

86-
def symmetry(self, domain):
86+
def symmetry(self, domain=None):
8787
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.
8888
8989
meaning that component :math:`c_0` is represented by component
@@ -101,15 +101,15 @@ def symmetry(self, domain):
101101
st = shape_to_strides(sh)
102102
# Map symmetries of subelement into index space of this
103103
# element
104-
for c0, c1 in e.symmetry().items():
104+
for c0, c1 in e.symmetry(domain).items():
105105
j0 = flatten_multiindex(c0, st) + j
106106
j1 = flatten_multiindex(c1, st) + j
107107
sm[(j0,)] = (j1,)
108108
# Update base index for next element
109109
j += product(sh)
110110
if j != product(self.value_shape(domain)):
111111
raise ValueError("Size mismatch in symmetry algorithm.")
112-
return sm or {}
112+
return sm
113113

114114
@property
115115
def sobolev_space(self):
@@ -133,7 +133,7 @@ def sub_elements(self):
133133
"""Return list of sub elements."""
134134
return self._sub_elements
135135

136-
def extract_subelement_component(self, domain, i):
136+
def extract_subelement_component(self, i, domain=None):
137137
"""Extract direct subelement index and subelement relative.
138138
139139
component index for a given component index.
@@ -142,13 +142,14 @@ def extract_subelement_component(self, domain, i):
142142
raise NotImplementedError
143143
if isinstance(i, int):
144144
i = (i,)
145-
self._check_component(i)
145+
self._check_component(i, domain)
146146

147147
# Select between indexing modes
148148
if len(self.value_shape(domain)) == 1:
149149
# Indexing into a long vector of flattened subelement
150150
# shapes
151151
j, = i
152+
j = int(j)
152153

153154
# Find subelement for this index
154155
for sub_element_index, e in enumerate(self._sub_elements):
@@ -172,13 +173,13 @@ def extract_subelement_component(self, domain, i):
172173
component = i[1:]
173174
return (sub_element_index, component)
174175

175-
def extract_component(self, i):
176+
def extract_component(self, i, domain=None):
176177
"""Recursively extract component index relative to a (simple) element.
177178
178179
and that element for given value component index.
179180
"""
180-
sub_element_index, component = self.extract_subelement_component(i)
181-
return self._sub_elements[sub_element_index].extract_component(component)
181+
sub_element_index, component = self.extract_subelement_component(i, domain)
182+
return self._sub_elements[sub_element_index].extract_component(component, domain)
182183

183184
def extract_subelement_reference_component(self, i):
184185
"""Extract direct subelement index and subelement relative.
@@ -193,6 +194,7 @@ def extract_subelement_reference_component(self, i):
193194
assert len(self.reference_value_shape) == 1
194195
# Indexing into a long vector of flattened subelement shapes
195196
j, = i
197+
j = int(j)
196198

197199
# Find subelement for this index
198200
for sub_element_index, e in enumerate(self._sub_elements):
@@ -217,20 +219,20 @@ def extract_reference_component(self, i):
217219
sub_element_index, reference_component = self.extract_subelement_reference_component(i)
218220
return self._sub_elements[sub_element_index].extract_reference_component(reference_component)
219221

220-
def is_cellwise_constant(self, component=None):
222+
def is_cellwise_constant(self, component=None, domain=None):
221223
"""Return whether the basis functions of this element is spatially constant over each cell."""
222224
if component is None:
223225
return all(e.is_cellwise_constant() for e in self.sub_elements)
224226
else:
225-
i, e = self.extract_component(component)
227+
i, e = self.extract_component(component, domain)
226228
return e.is_cellwise_constant()
227229

228-
def degree(self, component=None):
230+
def degree(self, component=None, domain=None):
229231
"""Return polynomial degree of finite element."""
230232
if component is None:
231233
return self._degree # from FiniteElementBase, computed as max of subelements in __init__
232234
else:
233-
i, e = self.extract_component(component)
235+
i, e = self.extract_component(component, domain)
234236
return e.degree()
235237

236238
@property
@@ -244,7 +246,7 @@ def embedded_superdegree(self):
244246
return max(e.embedded_superdegree for e in self.sub_elements)
245247

246248
def reconstruct(self, **kwargs):
247-
"""Doc."""
249+
"""Construct a new FiniteElement object with some properties replaced with new values."""
248250
cell = kwargs.pop('cell', None)
249251
if cell is None:
250252
cell = self.cell
@@ -257,7 +259,7 @@ def reconstruct(self, **kwargs):
257259
)
258260

259261
def variant(self):
260-
"""Doc."""
262+
"""Return the common variant to all subelements."""
261263
try:
262264
variant, = {e.variant() for e in self.sub_elements}
263265
return variant
@@ -266,7 +268,7 @@ def variant(self):
266268

267269
def __str__(self):
268270
"""Format as string for pretty printing."""
269-
tmp = ", ".join(str(element) for element in self._sub_elements)
271+
tmp = ", ".join(map(str, self._sub_elements))
270272
return "<Mixed element: (" + tmp + ")>"
271273

272274
def shortstr(self):
@@ -291,7 +293,6 @@ def __init__(self, family, cell=None, degree=None, dim=None,
291293
if isinstance(family, FiniteElementBase):
292294
sub_element = family
293295
cell = sub_element.cell
294-
variant = sub_element.variant()
295296
else:
296297
if cell is not None:
297298
cell = as_cell(cell)
@@ -326,13 +327,8 @@ def __init__(self, family, cell=None, degree=None, dim=None,
326327

327328
self._sub_element = sub_element
328329

329-
if variant is None:
330-
var_str = ""
331-
else:
332-
var_str = ", variant='" + variant + "'"
333-
334330
# Cache repr string
335-
self._repr = f"VectorElement({repr(sub_element)}, dim={dim}{var_str})"
331+
self._repr = f"VectorElement({repr(sub_element)}, dim={dim})"
336332

337333
def _make_cell(self):
338334
if self.num_sub_elements == 0:
@@ -388,7 +384,6 @@ def __init__(self, family, cell=None, degree=None, shape=None,
388384
if isinstance(family, FiniteElementBase):
389385
sub_element = family
390386
cell = sub_element.cell
391-
variant = sub_element.variant()
392387
else:
393388
if cell is not None:
394389
cell = as_cell(cell)
@@ -435,7 +430,7 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435430
if index in symmetry:
436431
continue
437432
sub_element_mapping[index] = len(sub_elements)
438-
sub_elements += [sub_element]
433+
sub_elements.append(sub_element)
439434

440435
# Update mapping for symmetry
441436
for index in indices:
@@ -466,14 +461,9 @@ def __init__(self, family, cell=None, degree=None, shape=None,
466461
self._sub_element_mapping = sub_element_mapping
467462
self._flattened_sub_element_mapping = flattened_sub_element_mapping
468463

469-
if variant is None:
470-
var_str = ""
471-
else:
472-
var_str = ", variant='" + variant + "'"
473-
474464
# Cache repr string
475465
self._repr = (f"TensorElement({repr(sub_element)}, shape={shape}, "
476-
f"symmetry={symmetry}{var_str})")
466+
f"symmetry={symmetry})")
477467

478468
def _make_cell(self):
479469
if self.num_sub_elements == 0:
@@ -518,25 +508,25 @@ def flattened_sub_element_mapping(self):
518508
"""Doc."""
519509
return self._flattened_sub_element_mapping
520510

521-
def extract_subelement_component(self, i):
511+
def extract_subelement_component(self, i, domain=None):
522512
"""Extract direct subelement index and subelement relative.
523513
524514
component index for a given component index.
525515
"""
526516
if isinstance(i, int):
527517
i = (i,)
528-
self._check_component(i)
518+
self._check_component(i, domain)
529519

530-
i = self.symmetry().get(i, i)
531-
l = len(self._shape) # noqa: E741
532-
ii = i[:l]
533-
jj = i[l:]
520+
i = self.symmetry(domain).get(i, i)
521+
rank = len(self._shape)
522+
ii = i[:rank]
523+
jj = i[rank:]
534524
if ii not in self._sub_element_mapping:
535525
raise ValueError(f"Illegal component index {i}.")
536526
k = self._sub_element_mapping[ii]
537527
return (k, jj)
538528

539-
def symmetry(self):
529+
def symmetry(self, domain=None):
540530
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.
541531
542532
meaning that component :math:`c_0` is represented by component

0 commit comments

Comments
 (0)