Skip to content

Commit 752703e

Browse files
committed
mixed_element: use CellSequence()
1 parent 1d43334 commit 752703e

File tree

1 file changed

+46
-19
lines changed

1 file changed

+46
-19
lines changed

finat/ufl/mixedelement.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
import numpy as np
1515

16-
from ufl.cell import as_cell
16+
from ufl.cell import CellSequence, as_cell
17+
from ufl.domain import MeshSequence
1718
from finat.ufl.finiteelement import FiniteElement
1819
from finat.ufl.finiteelementbase import FiniteElementBase
1920
from ufl.permutation import compute_indices
@@ -39,18 +40,6 @@ def __init__(self, *elements, **kwargs):
3940
elements = [MixedElement(e) if isinstance(e, (tuple, list)) else e
4041
for e in elements]
4142
self._sub_elements = elements
42-
43-
# Pick the first cell, for now all should be equal
44-
cells = tuple(sorted(set(element.cell for element in elements) - set([None])))
45-
self._cells = cells
46-
if cells:
47-
cell = cells[0]
48-
# Require that all elements are defined on the same cell
49-
if not all(c == cell for c in cells[1:]):
50-
raise ValueError("Sub elements must live on the same cell.")
51-
else:
52-
cell = None
53-
5443
# Check that all elements use the same quadrature scheme TODO:
5544
# We can allow the scheme not to be defined.
5645
if len(elements) == 0:
@@ -70,9 +59,16 @@ def __init__(self, *elements, **kwargs):
7059
# Initialize element data
7160
degrees = {e.degree() for e in self._sub_elements} - {None}
7261
degree = max_degree(degrees) if degrees else None
73-
FiniteElementBase.__init__(self, "Mixed", cell, degree, quad_scheme,
62+
FiniteElementBase.__init__(self, "Mixed", self._make_cell(), degree, quad_scheme,
7463
reference_value_shape)
7564

65+
def _make_cell(self):
66+
if self.num_sub_elements == 0:
67+
return
68+
else:
69+
cells = tuple(e.cell for e in self.sub_elements)
70+
return CellSequence(cells)
71+
7672
def __repr__(self):
7773
"""Doc."""
7874
return "MixedElement(" + ", ".join(repr(e) for e in self._sub_elements) + ")"
@@ -94,6 +90,8 @@ def symmetry(self, domain):
9490
:math:`c_1`.
9591
A component is a tuple of one or more ints.
9692
"""
93+
if isinstance(domain, MeshSequence):
94+
raise NotImplementedError
9795
# Build symmetry map from symmetries of subelements
9896
sm = {}
9997
# Base index of the current subelement into mixed value
@@ -140,6 +138,8 @@ def extract_subelement_component(self, domain, i):
140138
141139
component index for a given component index.
142140
"""
141+
if isinstance(domain, MeshSequence):
142+
raise NotImplementedError
143143
if isinstance(i, int):
144144
i = (i,)
145145
self._check_component(i)
@@ -245,7 +245,16 @@ def embedded_superdegree(self):
245245

246246
def reconstruct(self, **kwargs):
247247
"""Doc."""
248-
return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements])
248+
cell = kwargs.pop('cell', None)
249+
if cell is None:
250+
cell = self.cell
251+
else:
252+
if not isinstance(cell, CellSequence):
253+
# Allow for passing a single base cell.
254+
cell = CellSequence([cell] * len(self.sub_elements))
255+
return type(self)(
256+
*[e.reconstruct(cell=c, **kwargs) for c, e in zip(cell.cells, self.sub_elements)],
257+
)
249258

250259
def variant(self):
251260
"""Doc."""
@@ -307,8 +316,10 @@ def __init__(self, family, cell=None, degree=None, dim=None,
307316
reference_value_shape = (dim,) + sub_element.reference_value_shape
308317

309318
# Initialize element data
310-
MixedElement.__init__(self, sub_elements,
311-
reference_value_shape=reference_value_shape)
319+
MixedElement.__init__(
320+
self, sub_elements,
321+
reference_value_shape=reference_value_shape,
322+
)
312323

313324
FiniteElementBase.__init__(self, sub_element.family(), sub_element.cell, sub_element.degree(),
314325
sub_element.quadrature_scheme(), reference_value_shape)
@@ -323,6 +334,13 @@ def __init__(self, family, cell=None, degree=None, dim=None,
323334
# Cache repr string
324335
self._repr = f"VectorElement({repr(sub_element)}, dim={dim}{var_str})"
325336

337+
def _make_cell(self):
338+
if self.num_sub_elements == 0:
339+
return
340+
else:
341+
cell, = set(e.cell for e in self.sub_elements)
342+
return cell
343+
326344
def __repr__(self):
327345
"""Doc."""
328346
return self._repr
@@ -435,8 +453,10 @@ def __init__(self, family, cell=None, degree=None, shape=None,
435453

436454
reference_value_shape = reference_value_shape + sub_element.reference_value_shape
437455
# Initialize element data
438-
MixedElement.__init__(self, sub_elements,
439-
reference_value_shape=reference_value_shape)
456+
MixedElement.__init__(
457+
self, sub_elements,
458+
reference_value_shape=reference_value_shape,
459+
)
440460
self._family = sub_element.family()
441461
self._degree = sub_element.degree()
442462
self._sub_element = sub_element
@@ -454,6 +474,13 @@ def __init__(self, family, cell=None, degree=None, shape=None,
454474
self._repr = (f"TensorElement({repr(sub_element)}, shape={shape}, "
455475
f"symmetry={symmetry}{var_str})")
456476

477+
def _make_cell(self):
478+
if self.num_sub_elements == 0:
479+
return
480+
else:
481+
cell, = set(e.cell for e in self.sub_elements)
482+
return cell
483+
457484
@property
458485
def pullback(self):
459486
"""Get pull back."""

0 commit comments

Comments
 (0)