@@ -49,7 +49,7 @@ def __init__(self, *elements, **kwargs):
4949 if not all (e .quadrature_scheme () == quad_scheme for e in elements ):
5050 raise ValueError ("Quadrature scheme mismatch for sub elements of mixed element." )
5151
52- # Compute value sizes in global and reference configurations
52+ # Compute value sizes in reference configuration
5353 reference_value_size_sum = sum (product (s .reference_value_shape ) for s in self ._sub_elements )
5454
5555 # Default reference value shape: Treated simply as all
@@ -71,7 +71,7 @@ def _make_cell(self):
7171
7272 def __repr__ (self ):
7373 """Doc."""
74- return "MixedElement(" + ", " .join (repr ( e ) for e in self ._sub_elements ) + ")"
74+ return "MixedElement(" + ", " .join (map ( repr , self ._sub_elements ) ) + ")"
7575
7676 def _is_linear (self ):
7777 """Doc."""
@@ -83,7 +83,7 @@ def reconstruct_from_elements(self, *elements):
8383 return self
8484 return MixedElement (* elements )
8585
86- def symmetry (self , domain ):
86+ def symmetry (self , domain = None ):
8787 r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.
8888
8989 meaning that component :math:`c_0` is represented by component
@@ -101,15 +101,15 @@ def symmetry(self, domain):
101101 st = shape_to_strides (sh )
102102 # Map symmetries of subelement into index space of this
103103 # element
104- for c0 , c1 in e .symmetry ().items ():
104+ for c0 , c1 in e .symmetry (domain ).items ():
105105 j0 = flatten_multiindex (c0 , st ) + j
106106 j1 = flatten_multiindex (c1 , st ) + j
107107 sm [(j0 ,)] = (j1 ,)
108108 # Update base index for next element
109109 j += product (sh )
110110 if j != product (self .value_shape (domain )):
111111 raise ValueError ("Size mismatch in symmetry algorithm." )
112- return sm or {}
112+ return sm
113113
114114 @property
115115 def sobolev_space (self ):
@@ -133,7 +133,7 @@ def sub_elements(self):
133133 """Return list of sub elements."""
134134 return self ._sub_elements
135135
136- def extract_subelement_component (self , domain , i ):
136+ def extract_subelement_component (self , i , domain = None ):
137137 """Extract direct subelement index and subelement relative.
138138
139139 component index for a given component index.
@@ -142,13 +142,14 @@ def extract_subelement_component(self, domain, i):
142142 raise NotImplementedError
143143 if isinstance (i , int ):
144144 i = (i ,)
145- self ._check_component (i )
145+ self ._check_component (i , domain )
146146
147147 # Select between indexing modes
148148 if len (self .value_shape (domain )) == 1 :
149149 # Indexing into a long vector of flattened subelement
150150 # shapes
151151 j , = i
152+ j = int (j )
152153
153154 # Find subelement for this index
154155 for sub_element_index , e in enumerate (self ._sub_elements ):
@@ -172,13 +173,13 @@ def extract_subelement_component(self, domain, i):
172173 component = i [1 :]
173174 return (sub_element_index , component )
174175
175- def extract_component (self , i ):
176+ def extract_component (self , i , domain = None ):
176177 """Recursively extract component index relative to a (simple) element.
177178
178179 and that element for given value component index.
179180 """
180- sub_element_index , component = self .extract_subelement_component (i )
181- return self ._sub_elements [sub_element_index ].extract_component (component )
181+ sub_element_index , component = self .extract_subelement_component (i , domain )
182+ return self ._sub_elements [sub_element_index ].extract_component (component , domain )
182183
183184 def extract_subelement_reference_component (self , i ):
184185 """Extract direct subelement index and subelement relative.
@@ -193,6 +194,7 @@ def extract_subelement_reference_component(self, i):
193194 assert len (self .reference_value_shape ) == 1
194195 # Indexing into a long vector of flattened subelement shapes
195196 j , = i
197+ j = int (j )
196198
197199 # Find subelement for this index
198200 for sub_element_index , e in enumerate (self ._sub_elements ):
@@ -217,20 +219,20 @@ def extract_reference_component(self, i):
217219 sub_element_index , reference_component = self .extract_subelement_reference_component (i )
218220 return self ._sub_elements [sub_element_index ].extract_reference_component (reference_component )
219221
220- def is_cellwise_constant (self , component = None ):
222+ def is_cellwise_constant (self , component = None , domain = None ):
221223 """Return whether the basis functions of this element is spatially constant over each cell."""
222224 if component is None :
223225 return all (e .is_cellwise_constant () for e in self .sub_elements )
224226 else :
225- i , e = self .extract_component (component )
227+ i , e = self .extract_component (component , domain )
226228 return e .is_cellwise_constant ()
227229
228- def degree (self , component = None ):
230+ def degree (self , component = None , domain = None ):
229231 """Return polynomial degree of finite element."""
230232 if component is None :
231233 return self ._degree # from FiniteElementBase, computed as max of subelements in __init__
232234 else :
233- i , e = self .extract_component (component )
235+ i , e = self .extract_component (component , domain )
234236 return e .degree ()
235237
236238 @property
@@ -244,7 +246,7 @@ def embedded_superdegree(self):
244246 return max (e .embedded_superdegree for e in self .sub_elements )
245247
246248 def reconstruct (self , ** kwargs ):
247- """Doc ."""
249+ """Construct a new FiniteElement object with some properties replaced with new values ."""
248250 cell = kwargs .pop ('cell' , None )
249251 if cell is None :
250252 cell = self .cell
@@ -257,7 +259,7 @@ def reconstruct(self, **kwargs):
257259 )
258260
259261 def variant (self ):
260- """Doc ."""
262+ """Return the common variant to all subelements ."""
261263 try :
262264 variant , = {e .variant () for e in self .sub_elements }
263265 return variant
@@ -266,7 +268,7 @@ def variant(self):
266268
267269 def __str__ (self ):
268270 """Format as string for pretty printing."""
269- tmp = ", " .join (str ( element ) for element in self ._sub_elements )
271+ tmp = ", " .join (map ( str , self ._sub_elements ) )
270272 return "<Mixed element: (" + tmp + ")>"
271273
272274 def shortstr (self ):
@@ -291,7 +293,6 @@ def __init__(self, family, cell=None, degree=None, dim=None,
291293 if isinstance (family , FiniteElementBase ):
292294 sub_element = family
293295 cell = sub_element .cell
294- variant = sub_element .variant ()
295296 else :
296297 if cell is not None :
297298 cell = as_cell (cell )
@@ -326,13 +327,8 @@ def __init__(self, family, cell=None, degree=None, dim=None,
326327
327328 self ._sub_element = sub_element
328329
329- if variant is None :
330- var_str = ""
331- else :
332- var_str = ", variant='" + variant + "'"
333-
334330 # Cache repr string
335- self ._repr = f"VectorElement({ repr (sub_element )} , dim={ dim } { var_str } )"
331+ self ._repr = f"VectorElement({ repr (sub_element )} , dim={ dim } )"
336332
337333 def _make_cell (self ):
338334 if self .num_sub_elements == 0 :
@@ -388,7 +384,6 @@ def __init__(self, family, cell=None, degree=None, shape=None,
388384 if isinstance (family , FiniteElementBase ):
389385 sub_element = family
390386 cell = sub_element .cell
391- variant = sub_element .variant ()
392387 else :
393388 if cell is not None :
394389 cell = as_cell (cell )
@@ -435,7 +430,7 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435430 if index in symmetry :
436431 continue
437432 sub_element_mapping [index ] = len (sub_elements )
438- sub_elements += [ sub_element ]
433+ sub_elements . append ( sub_element )
439434
440435 # Update mapping for symmetry
441436 for index in indices :
@@ -466,14 +461,9 @@ def __init__(self, family, cell=None, degree=None, shape=None,
466461 self ._sub_element_mapping = sub_element_mapping
467462 self ._flattened_sub_element_mapping = flattened_sub_element_mapping
468463
469- if variant is None :
470- var_str = ""
471- else :
472- var_str = ", variant='" + variant + "'"
473-
474464 # Cache repr string
475465 self ._repr = (f"TensorElement({ repr (sub_element )} , shape={ shape } , "
476- f"symmetry={ symmetry } { var_str } )" )
466+ f"symmetry={ symmetry } )" )
477467
478468 def _make_cell (self ):
479469 if self .num_sub_elements == 0 :
@@ -518,25 +508,25 @@ def flattened_sub_element_mapping(self):
518508 """Doc."""
519509 return self ._flattened_sub_element_mapping
520510
521- def extract_subelement_component (self , i ):
511+ def extract_subelement_component (self , i , domain = None ):
522512 """Extract direct subelement index and subelement relative.
523513
524514 component index for a given component index.
525515 """
526516 if isinstance (i , int ):
527517 i = (i ,)
528- self ._check_component (i )
518+ self ._check_component (i , domain )
529519
530- i = self .symmetry ().get (i , i )
531- l = len (self ._shape ) # noqa: E741
532- ii = i [:l ]
533- jj = i [l :]
520+ i = self .symmetry (domain ).get (i , i )
521+ rank = len (self ._shape )
522+ ii = i [:rank ]
523+ jj = i [rank :]
534524 if ii not in self ._sub_element_mapping :
535525 raise ValueError (f"Illegal component index { i } ." )
536526 k = self ._sub_element_mapping [ii ]
537527 return (k , jj )
538528
539- def symmetry (self ):
529+ def symmetry (self , domain = None ):
540530 r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.
541531
542532 meaning that component :math:`c_0` is represented by component
0 commit comments