Skip to content

Commit 67ed752

Browse files
connorjwarddham
andauthored
Fix assemble with multiple kernels and subdomains (#3135)
* Fix issue 3125 --------- Co-authored-by: David A. Ham <[email protected]>
1 parent e884121 commit 67ed752

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

firedrake/assemble.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ def replace_tensor(self, tensor):
692692
return
693693

694694
# TODO We should have some proper checks here
695-
for lknl, parloop in zip(self.local_kernels, self.parloops):
695+
for (lknl, _), parloop in zip(self.local_kernels, self.parloops):
696696
data = _FormHandler.index_tensor(tensor, self._form, lknl.indices, self.diagonal)
697697
parloop.arguments[0].data = data
698698
self._tensor = tensor
@@ -703,6 +703,14 @@ def execute_parloops(self):
703703

704704
@cached_property
705705
def local_kernels(self):
706+
"""Return local kernels and their subdomain IDs.
707+
708+
Returns
709+
-------
710+
tuple
711+
Collection of ``(local_kernel, subdomain_id)`` 2-tuples, one for
712+
each possible combination.
713+
"""
706714
try:
707715
topology, = set(d.topology for d in self._form.ufl_domains())
708716
except ValueError:
@@ -714,16 +722,26 @@ def local_kernels(self):
714722
raise NotImplementedError("Assembly with multiple meshes is not supported")
715723

716724
if isinstance(self._form, ufl.Form):
717-
return tsfc_interface.compile_form(self._form, "form", diagonal=self.diagonal,
718-
parameters=self._form_compiler_params)
725+
kernels = tsfc_interface.compile_form(
726+
self._form, "form", diagonal=self.diagonal,
727+
parameters=self._form_compiler_params
728+
)
719729
elif isinstance(self._form, slate.TensorBase):
720-
return slac.compile_expression(self._form, compiler_parameters=self._form_compiler_params)
730+
kernels = slac.compile_expression(
731+
self._form,
732+
compiler_parameters=self._form_compiler_params
733+
)
721734
else:
722735
raise AssertionError
736+
return tuple(
737+
(k, subdomain_id) for k in kernels for subdomain_id in k.kinfo.subdomain_id
738+
)
723739

724740
@cached_property
725741
def all_integer_subdomain_ids(self):
726-
return tsfc_interface.gather_integer_subdomain_ids(self.local_kernels)
742+
return tsfc_interface.gather_integer_subdomain_ids(
743+
{k for k, _ in self.local_kernels}
744+
)
727745

728746
@cached_property
729747
def global_kernels(self):
@@ -732,20 +750,28 @@ def global_kernels(self):
732750
self._form, tsfc_knl, subdomain_id, self.all_integer_subdomain_ids,
733751
diagonal=self.diagonal, unroll=self.needs_unrolling(tsfc_knl, self._bcs)
734752
)
735-
for tsfc_knl in self.local_kernels
736-
for subdomain_id in tsfc_knl.kinfo.subdomain_id
753+
for tsfc_knl, subdomain_id in self.local_kernels
737754
)
738755

739756
@cached_property
740757
def parloops(self):
741-
return tuple(
742-
ParloopBuilder(
743-
self._form, lknl, gknl, self._tensor, subdomain_id,
744-
self.all_integer_subdomain_ids, diagonal=self.diagonal,
745-
lgmaps=self.collect_lgmaps(lknl, self._bcs)).build()
746-
for lknl, gknl in zip(self.local_kernels, self.global_kernels)
747-
for subdomain_id in lknl.kinfo.subdomain_id
748-
)
758+
loops = []
759+
for (local_kernel, subdomain_id), global_kernel in zip(
760+
self.local_kernels, self.global_kernels
761+
):
762+
loops.append(
763+
ParloopBuilder(
764+
self._form,
765+
local_kernel,
766+
global_kernel,
767+
self._tensor,
768+
subdomain_id,
769+
self.all_integer_subdomain_ids,
770+
diagonal=self.diagonal,
771+
lgmaps=self.collect_lgmaps(local_kernel, self._bcs)
772+
).build()
773+
)
774+
return tuple(loops)
749775

750776
def needs_unrolling(self, local_knl, bcs):
751777
"""Do we need to address matrix elements directly rather than in

tests/regression/test_assemble.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,19 @@ def test_assemble_mixed_function_sparse():
266266
f.sub(4).interpolate(Constant(3.0))
267267
v = assemble((inner(f[1], f[1]) + inner(f[4], f[4])) * dx)
268268
assert np.allclose(v, 13.0)
269+
270+
271+
def test_3125():
272+
# see https://github.com/firedrakeproject/firedrake/issues/3125
273+
mesh = UnitSquareMesh(3, 3)
274+
V = VectorFunctionSpace(mesh, "CG", 2)
275+
W = FunctionSpace(mesh, "CG", 1)
276+
Z = MixedFunctionSpace([V, W])
277+
z = Function(Z)
278+
u, p = split(z)
279+
tst = TestFunction(Z)
280+
v, q = split(tst)
281+
d = Function(W)
282+
F = inner(z, tst)*dx + inner(u, v)/(d+p)*dx(2, degree=10)
283+
# should run without error
284+
solve(F == 0, z)

0 commit comments

Comments
 (0)