1313
1414import numpy as np
1515
16- from ufl .cell import as_cell
16+ from ufl .cell import CellSequence , as_cell
17+ from ufl .domain import MeshSequence
1718from finat .ufl .finiteelement import FiniteElement
1819from finat .ufl .finiteelementbase import FiniteElementBase
1920from ufl .permutation import compute_indices
@@ -39,18 +40,6 @@ def __init__(self, *elements, **kwargs):
3940 elements = [MixedElement (e ) if isinstance (e , (tuple , list )) else e
4041 for e in elements ]
4142 self ._sub_elements = elements
42-
43- # Pick the first cell, for now all should be equal
44- cells = tuple (sorted (set (element .cell for element in elements ) - set ([None ])))
45- self ._cells = cells
46- if cells :
47- cell = cells [0 ]
48- # Require that all elements are defined on the same cell
49- if not all (c == cell for c in cells [1 :]):
50- raise ValueError ("Sub elements must live on the same cell." )
51- else :
52- cell = None
53-
5443 # Check that all elements use the same quadrature scheme TODO:
5544 # We can allow the scheme not to be defined.
5645 if len (elements ) == 0 :
@@ -70,9 +59,16 @@ def __init__(self, *elements, **kwargs):
7059 # Initialize element data
7160 degrees = {e .degree () for e in self ._sub_elements } - {None }
7261 degree = max_degree (degrees ) if degrees else None
73- FiniteElementBase .__init__ (self , "Mixed" , cell , degree , quad_scheme ,
62+ FiniteElementBase .__init__ (self , "Mixed" , self . _make_cell () , degree , quad_scheme ,
7463 reference_value_shape )
7564
65+ def _make_cell (self ):
66+ if self .num_sub_elements == 0 :
67+ return
68+ else :
69+ cells = tuple (e .cell for e in self .sub_elements )
70+ return CellSequence (cells )
71+
7672 def __repr__ (self ):
7773 """Doc."""
7874 return "MixedElement(" + ", " .join (repr (e ) for e in self ._sub_elements ) + ")"
@@ -94,6 +90,8 @@ def symmetry(self, domain):
9490 :math:`c_1`.
9591 A component is a tuple of one or more ints.
9692 """
93+ if isinstance (domain , MeshSequence ):
94+ raise NotImplementedError
9795 # Build symmetry map from symmetries of subelements
9896 sm = {}
9997 # Base index of the current subelement into mixed value
@@ -140,6 +138,8 @@ def extract_subelement_component(self, domain, i):
140138
141139 component index for a given component index.
142140 """
141+ if isinstance (domain , MeshSequence ):
142+ raise NotImplementedError
143143 if isinstance (i , int ):
144144 i = (i ,)
145145 self ._check_component (i )
@@ -245,7 +245,16 @@ def embedded_superdegree(self):
245245
246246 def reconstruct (self , ** kwargs ):
247247 """Doc."""
248- return MixedElement (* [e .reconstruct (** kwargs ) for e in self .sub_elements ])
248+ cell = kwargs .pop ('cell' , None )
249+ if cell is None :
250+ cell = self .cell
251+ else :
252+ if not isinstance (cell , CellSequence ):
253+ # Allow for passing a single base cell.
254+ cell = CellSequence ([cell ] * len (self .sub_elements ))
255+ return type (self )(
256+ * [e .reconstruct (cell = c , ** kwargs ) for c , e in zip (cell .cells , self .sub_elements )],
257+ )
249258
250259 def variant (self ):
251260 """Doc."""
@@ -307,8 +316,10 @@ def __init__(self, family, cell=None, degree=None, dim=None,
307316 reference_value_shape = (dim ,) + sub_element .reference_value_shape
308317
309318 # Initialize element data
310- MixedElement .__init__ (self , sub_elements ,
311- reference_value_shape = reference_value_shape )
319+ MixedElement .__init__ (
320+ self , sub_elements ,
321+ reference_value_shape = reference_value_shape ,
322+ )
312323
313324 FiniteElementBase .__init__ (self , sub_element .family (), sub_element .cell , sub_element .degree (),
314325 sub_element .quadrature_scheme (), reference_value_shape )
@@ -323,6 +334,13 @@ def __init__(self, family, cell=None, degree=None, dim=None,
323334 # Cache repr string
324335 self ._repr = f"VectorElement({ repr (sub_element )} , dim={ dim } { var_str } )"
325336
337+ def _make_cell (self ):
338+ if self .num_sub_elements == 0 :
339+ return
340+ else :
341+ cell , = set (e .cell for e in self .sub_elements )
342+ return cell
343+
326344 def __repr__ (self ):
327345 """Doc."""
328346 return self ._repr
@@ -435,8 +453,10 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435453
436454 reference_value_shape = reference_value_shape + sub_element .reference_value_shape
437455 # Initialize element data
438- MixedElement .__init__ (self , sub_elements ,
439- reference_value_shape = reference_value_shape )
456+ MixedElement .__init__ (
457+ self , sub_elements ,
458+ reference_value_shape = reference_value_shape ,
459+ )
440460 self ._family = sub_element .family ()
441461 self ._degree = sub_element .degree ()
442462 self ._sub_element = sub_element
@@ -454,6 +474,13 @@ def __init__(self, family, cell=None, degree=None, shape=None,
454474 self ._repr = (f"TensorElement({ repr (sub_element )} , shape={ shape } , "
455475 f"symmetry={ symmetry } { var_str } )" )
456476
477+ def _make_cell (self ):
478+ if self .num_sub_elements == 0 :
479+ return
480+ else :
481+ cell , = set (e .cell for e in self .sub_elements )
482+ return cell
483+
457484 @property
458485 def pullback (self ):
459486 """Get pull back."""
0 commit comments