Skip to content

Commit cf696fe

Browse files
committed
mixed_element: use CellSequence()
1 parent d9ed734 commit cf696fe

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

finat/ufl/mixedelement.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
import numpy as np
1515

16-
from ufl.cell import as_cell
16+
from ufl.cell import CellSequence, as_cell
17+
from ufl.domain import MeshSequence
1718
from finat.ufl.finiteelement import FiniteElement
1819
from finat.ufl.finiteelementbase import FiniteElementBase
1920
from ufl.permutation import compute_indices
@@ -26,7 +27,7 @@ class MixedElement(FiniteElementBase):
2627
"""A finite element composed of a nested hierarchy of mixed or simple elements."""
2728
__slots__ = ("_sub_elements", "_cells")
2829

29-
def __init__(self, *elements, **kwargs):
30+
def __init__(self, *elements, make_cell_sequence=True, **kwargs):
3031
"""Create mixed finite element from given list of elements."""
3132
if type(self) is MixedElement:
3233
if kwargs:
@@ -39,18 +40,15 @@ def __init__(self, *elements, **kwargs):
3940
elements = [MixedElement(e) if isinstance(e, (tuple, list)) else e
4041
for e in elements]
4142
self._sub_elements = elements
42-
43-
# Pick the first cell, for now all should be equal
44-
cells = tuple(sorted(set(element.cell for element in elements) - set([None])))
45-
self._cells = cells
43+
cells = tuple(e.cell for e in elements)
4644
if cells:
47-
cell = cells[0]
48-
# Require that all elements are defined on the same cell
49-
if not all(c == cell for c in cells[1:]):
50-
raise ValueError("Sub elements must live on the same cell.")
45+
if make_cell_sequence:
46+
cell = CellSequence(cells)
47+
else:
48+
# VectorElement or TensorElement.
49+
cell, = set(cells)
5150
else:
5251
cell = None
53-
5452
# Check that all elements use the same quadrature scheme TODO:
5553
# We can allow the scheme not to be defined.
5654
if len(elements) == 0:
@@ -94,6 +92,8 @@ def symmetry(self, domain):
9492
:math:`c_1`.
9593
A component is a tuple of one or more ints.
9694
"""
95+
if isinstance(domain, MeshSequence):
96+
raise NotImplementedError
9797
# Build symmetry map from symmetries of subelements
9898
sm = {}
9999
# Base index of the current subelement into mixed value
@@ -140,6 +140,8 @@ def extract_subelement_component(self, domain, i):
140140
141141
component index for a given component index.
142142
"""
143+
if isinstance(domain, MeshSequence):
144+
raise NotImplementedError
143145
if isinstance(i, int):
144146
i = (i,)
145147
self._check_component(i)
@@ -245,7 +247,26 @@ def embedded_superdegree(self):
245247

246248
def reconstruct(self, **kwargs):
247249
"""Doc."""
248-
return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements])
250+
cell = kwargs.pop('cell', None)
251+
if isinstance(self.cell, CellSequence):
252+
if cell is None:
253+
cell = self.cell
254+
else:
255+
if not isinstance(cell, CellSequence):
256+
# Allow for passing a single base cell.
257+
cell = CellSequence([cell] * len(self.sub_elements))
258+
cells = cell.cells
259+
else:
260+
if cell is None:
261+
cell = self.cell
262+
else:
263+
if isinstance(cell, CellSequence):
264+
raise TypeError(f"Input cell(={cell}) must not be CellSequence")
265+
cells = [cell] * len(self.sub_elements)
266+
return MixedElement(
267+
*[e.reconstruct(cell=c, **kwargs) for c, e in zip(cells, self.sub_elements)],
268+
make_cell_sequence=isinstance(self.cell, CellSequence),
269+
)
249270

250271
def variant(self):
251272
"""Doc."""
@@ -307,8 +328,11 @@ def __init__(self, family, cell=None, degree=None, dim=None,
307328
reference_value_shape = (dim,) + sub_element.reference_value_shape
308329

309330
# Initialize element data
310-
MixedElement.__init__(self, sub_elements,
311-
reference_value_shape=reference_value_shape)
331+
MixedElement.__init__(
332+
self, sub_elements,
333+
reference_value_shape=reference_value_shape,
334+
make_cell_sequence=False,
335+
)
312336

313337
FiniteElementBase.__init__(self, sub_element.family(), sub_element.cell, sub_element.degree(),
314338
sub_element.quadrature_scheme(), reference_value_shape)
@@ -435,8 +459,11 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435459

436460
reference_value_shape = reference_value_shape + sub_element.reference_value_shape
437461
# Initialize element data
438-
MixedElement.__init__(self, sub_elements,
439-
reference_value_shape=reference_value_shape)
462+
MixedElement.__init__(
463+
self, sub_elements,
464+
reference_value_shape=reference_value_shape,
465+
make_cell_sequence=False,
466+
)
440467
self._family = sub_element.family()
441468
self._degree = sub_element.degree()
442469
self._sub_element = sub_element

0 commit comments

Comments
 (0)