Skip to content

Commit e55085b

Browse files
Fix for vom-to-vom permutation matrix on Vector-valued function spaces (#4510)
* fix readonly=True * remove _get_sizes; save nleaves as attribute * add tensor-valued test --------- Co-authored-by: Pablo Brubeck <[email protected]>
1 parent 29e2ed7 commit e55085b

File tree

2 files changed

+65
-20
lines changed

2 files changed

+65
-20
lines changed

firedrake/interpolation.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,10 +1521,13 @@ def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments):
15211521
self.arguments = arguments
15221522
# Calculate correct local and global sizes for the matrix
15231523
nroots, leaves, _ = sf.getGraph()
1524-
nleaves = len(leaves)
1524+
self.nleaves = len(leaves)
15251525
self._local_sizes = V.comm.allgather(nroots)
1526-
self.source_size = (nroots, sum(self._local_sizes))
1527-
self.target_size = (nleaves, self.V.comm.allreduce(nleaves, op=MPI.SUM))
1526+
self.source_size = (self.V.block_size * nroots, self.V.block_size * sum(self._local_sizes))
1527+
self.target_size = (
1528+
self.V.block_size * self.nleaves,
1529+
self.V.block_size * V.comm.allreduce(self.nleaves, op=MPI.SUM),
1530+
)
15281531

15291532
@property
15301533
def mpi_type(self):
@@ -1565,7 +1568,7 @@ def expr_as_coeff(self, source_vec=None):
15651568
raise ValueError("Need to provide a source dat for the argument!")
15661569
arg = self.arguments[0]
15671570
arg_coeff = firedrake.Function(arg.function_space())
1568-
arg_coeff.dat.data_wo[:] = source_vec.getArray().reshape(
1571+
arg_coeff.dat.data_wo[:] = source_vec.getArray(readonly=True).reshape(
15691572
arg_coeff.dat.data_wo.shape
15701573
)
15711574
coeff_expr = ufl.replace(self.expr, {arg: arg_coeff})
@@ -1643,14 +1646,6 @@ def multTranspose(self, mat, source_vec, target_vec):
16431646
target_vec.zeroEntries()
16441647
self.reduce(source_vec, target_vec)
16451648

1646-
def _get_sizes(self):
1647-
nroots, leaves, _ = self.sf.getGraph()
1648-
nleaves = len(leaves)
1649-
local_sizes = self.V.comm.allgather(nroots)
1650-
source_size = (nroots, sum(local_sizes))
1651-
target_size = (nleaves, self.V.comm.allreduce(nleaves, op=MPI.SUM))
1652-
return source_size, target_size
1653-
16541649
def _create_permutation_mat(self):
16551650
"""Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to
16561651
its input ordering vertex-only mesh"""
@@ -1659,25 +1654,23 @@ def _create_permutation_mat(self):
16591654
start = sum(self._local_sizes[:self.V.comm.rank])
16601655
end = start + self.source_size[0]
16611656
contiguous_indices = numpy.arange(start, end, dtype=utils.IntType)
1662-
perm = numpy.zeros(self.target_size[0], dtype=utils.IntType)
1657+
perm = numpy.zeros(self.nleaves, dtype=utils.IntType)
16631658
self.sf.bcastBegin(MPI.INT, contiguous_indices, perm, MPI.REPLACE)
16641659
self.sf.bcastEnd(MPI.INT, contiguous_indices, perm, MPI.REPLACE)
16651660
rows = numpy.arange(self.target_size[0] + 1, dtype=utils.IntType)
1666-
mat.setValuesCSR(rows, perm, numpy.ones_like(perm, dtype=utils.IntType))
1661+
cols = (self.V.block_size * perm[:, None] + numpy.arange(self.V.block_size, dtype=utils.IntType)[None, :]).reshape(-1)
1662+
mat.setValuesCSR(rows, cols, numpy.ones_like(cols, dtype=utils.IntType))
16671663
mat.assemble()
16681664
if self.forward_reduce:
16691665
mat.transpose()
16701666
return mat
16711667

16721668
def _wrap_dummy_mat(self):
16731669
mat = PETSc.Mat().create(comm=self.V.comm)
1674-
dim = self.V.value_size
1675-
source_size = tuple(dim * i for i in self.source_size)
1676-
target_size = tuple(dim * i for i in self.target_size)
16771670
if self.forward_reduce:
1678-
mat_size = (source_size, target_size)
1671+
mat_size = (self.source_size, self.target_size)
16791672
else:
1680-
mat_size = (target_size, source_size)
1673+
mat_size = (self.target_size, self.source_size)
16811674
mat.setSizes(mat_size)
16821675
mat.setType(mat.Type.PYTHON)
16831676
mat.setPythonContext(self)

tests/firedrake/vertexonly/test_vertex_only_fs.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def functionspace_tests(vm, petsc_raises):
120120
idxs_to_include = input_ordering_parent_cell_nums != -1
121121
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension()), axis=1))
122122
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
123-
123+
# Using permutation matrix
124+
perm_mat = assemble(interpolate(TestFunction(V), W, matfree=False))
125+
h2 = assemble(perm_mat @ g)
126+
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
124127
# check other interpolation APIs work identically
125128
h2 = assemble(interpolate(g, W))
126129
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
@@ -225,6 +228,10 @@ def vectorfunctionspace_tests(vm, petsc_raises):
225228
idxs_to_include = input_ordering_parent_cell_nums != -1
226229
assert np.allclose(h.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
227230
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
231+
# Using permutation matrix
232+
perm_mat = assemble(interpolate(TestFunction(V), W, matfree=False))
233+
h2 = assemble(perm_mat @ g)
234+
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
228235
# check other interpolation APIs work identically
229236
h2 = assemble(interpolate(g, W))
230237
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
@@ -353,3 +360,48 @@ def test_input_ordering_missing_point():
353360
assert not len(data_input_ordering.dat.data_ro)
354361
# Accessing data_ro [*here] is collective, hence this redundant call
355362
_ = len(data_input_ordering.dat.data_ro)
363+
364+
365+
@pytest.fixture(
366+
params=[
367+
((2, 2), None),
368+
(None, True),
369+
((), None),
370+
((2, 3), None),
371+
]
372+
)
373+
def tensorfs_and_expr(request):
374+
shape, symmetry = request.param
375+
np.random.seed(0)
376+
mesh = UnitSquareMesh(2, 2)
377+
coords = np.random.random_sample(size=(10, 2))
378+
vom = VertexOnlyMesh(mesh, coords)
379+
380+
V = TensorFunctionSpace(vom, "DG", 0, shape=shape, symmetry=symmetry)
381+
W = TensorFunctionSpace(vom.input_ordering, "DG", 0, shape=shape, symmetry=symmetry)
382+
383+
x = SpatialCoordinate(vom)
384+
if shape == ():
385+
expr = inner(x, x)
386+
elif shape is None or shape == (2, 2):
387+
expr = outer(x, x) + Identity(2)
388+
elif shape == (2, 3):
389+
a = as_vector([x[0], x[1]])
390+
b = as_vector([x[0], x[1], Constant(1.0)])
391+
expr = outer(a, b)
392+
393+
return V, W, expr
394+
395+
396+
@pytest.mark.parallel([1, 3])
397+
def test_tensorfs_permutation(tensorfs_and_expr):
398+
V, W, expr = tensorfs_and_expr
399+
f = Function(V)
400+
f.interpolate(expr)
401+
f_in_W = assemble(interpolate(f, W))
402+
python_mat = assemble(interpolate(TestFunction(V), W, matfree=False))
403+
f_in_W_2 = assemble(python_mat @ f)
404+
assert np.allclose(f_in_W.dat.data_ro, f_in_W_2.dat.data_ro)
405+
petsc_mat = assemble(interpolate(TestFunction(V), W, matfree=True))
406+
f_in_W_petsc = assemble(petsc_mat @ f)
407+
assert np.allclose(f_in_W.dat.data_ro, f_in_W_petsc.dat.data_ro)

0 commit comments

Comments
 (0)