1313
1414import numpy as np
1515
16- from ufl .cell import as_cell
16+ from ufl .cell import CellSequence , as_cell
17+ from ufl .domain import MeshSequence
1718from finat .ufl .finiteelement import FiniteElement
1819from finat .ufl .finiteelementbase import FiniteElementBase
1920from ufl .permutation import compute_indices
@@ -26,7 +27,7 @@ class MixedElement(FiniteElementBase):
2627 """A finite element composed of a nested hierarchy of mixed or simple elements."""
2728 __slots__ = ("_sub_elements" , "_cells" )
2829
29- def __init__ (self , * elements , ** kwargs ):
30+ def __init__ (self , * elements , make_cell_sequence = True , ** kwargs ):
3031 """Create mixed finite element from given list of elements."""
3132 if type (self ) is MixedElement :
3233 if kwargs :
@@ -39,18 +40,15 @@ def __init__(self, *elements, **kwargs):
3940 elements = [MixedElement (e ) if isinstance (e , (tuple , list )) else e
4041 for e in elements ]
4142 self ._sub_elements = elements
42-
43- # Pick the first cell, for now all should be equal
44- cells = tuple (sorted (set (element .cell for element in elements ) - set ([None ])))
45- self ._cells = cells
43+ cells = tuple (e .cell for e in elements )
4644 if cells :
47- cell = cells [0 ]
48- # Require that all elements are defined on the same cell
49- if not all (c == cell for c in cells [1 :]):
50- raise ValueError ("Sub elements must live on the same cell." )
45+ if make_cell_sequence :
46+ cell = CellSequence (cells )
47+ else :
48+ # VectorElement or TensorElement.
49+ cell , = set (cells )
5150 else :
5251 cell = None
53-
5452 # Check that all elements use the same quadrature scheme TODO:
5553 # We can allow the scheme not to be defined.
5654 if len (elements ) == 0 :
@@ -94,6 +92,8 @@ def symmetry(self, domain):
9492 :math:`c_1`.
9593 A component is a tuple of one or more ints.
9694 """
95+ if isinstance (domain , MeshSequence ):
96+ raise NotImplementedError
9797 # Build symmetry map from symmetries of subelements
9898 sm = {}
9999 # Base index of the current subelement into mixed value
@@ -140,6 +140,8 @@ def extract_subelement_component(self, domain, i):
140140
141141 component index for a given component index.
142142 """
143+ if isinstance (domain , MeshSequence ):
144+ raise NotImplementedError
143145 if isinstance (i , int ):
144146 i = (i ,)
145147 self ._check_component (i )
@@ -245,7 +247,26 @@ def embedded_superdegree(self):
245247
246248 def reconstruct (self , ** kwargs ):
247249 """Doc."""
248- return MixedElement (* [e .reconstruct (** kwargs ) for e in self .sub_elements ])
250+ cell = kwargs .pop ('cell' , None )
251+ if isinstance (self .cell , CellSequence ):
252+ if cell is None :
253+ cell = self .cell
254+ else :
255+ if not isinstance (cell , CellSequence ):
256+ # Allow for passing a single base cell.
257+ cell = CellSequence ([cell ] * len (self .sub_elements ))
258+ cells = cell .cells
259+ else :
260+ if cell is None :
261+ cell = self .cell
262+ else :
263+ if isinstance (cell , CellSequence ):
264+ raise TypeError (f"Input cell(={ cell } ) must not be CellSequence" )
265+ cells = [cell ] * len (self .sub_elements )
266+ return MixedElement (
267+ * [e .reconstruct (cell = c , ** kwargs ) for c , e in zip (cells , self .sub_elements )],
268+ make_cell_sequence = isinstance (self .cell , CellSequence ),
269+ )
249270
250271 def variant (self ):
251272 """Doc."""
@@ -307,8 +328,11 @@ def __init__(self, family, cell=None, degree=None, dim=None,
307328 reference_value_shape = (dim ,) + sub_element .reference_value_shape
308329
309330 # Initialize element data
310- MixedElement .__init__ (self , sub_elements ,
311- reference_value_shape = reference_value_shape )
331+ MixedElement .__init__ (
332+ self , sub_elements ,
333+ reference_value_shape = reference_value_shape ,
334+ make_cell_sequence = False ,
335+ )
312336
313337 FiniteElementBase .__init__ (self , sub_element .family (), sub_element .cell , sub_element .degree (),
314338 sub_element .quadrature_scheme (), reference_value_shape )
@@ -435,8 +459,11 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435459
436460 reference_value_shape = reference_value_shape + sub_element .reference_value_shape
437461 # Initialize element data
438- MixedElement .__init__ (self , sub_elements ,
439- reference_value_shape = reference_value_shape )
462+ MixedElement .__init__ (
463+ self , sub_elements ,
464+ reference_value_shape = reference_value_shape ,
465+ make_cell_sequence = False ,
466+ )
440467 self ._family = sub_element .family ()
441468 self ._degree = sub_element .degree ()
442469 self ._sub_element = sub_element
0 commit comments