Skip to content

Commit 395b8ea

Browse files
authored
PMGPC: use native matrix-free (adjoint) interpolation (#4602)
* PMGPC: use native matrix-free (adjoint) interpolation
1 parent e8eb716 commit 395b8ea

File tree

2 files changed

+60
-171
lines changed

2 files changed

+60
-171
lines changed

firedrake/preconditioners/pmg.py

Lines changed: 52 additions & 128 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from firedrake.nullspace import VectorSpaceBasis, MixedVectorSpaceBasis
99
from firedrake.solving_utils import _SNESContext
1010
from firedrake.tsfc_interface import extract_numbered_coefficients
11-
from firedrake.utils import ScalarType_c, IntType_c, cached_property
11+
from firedrake.utils import IntType_c, cached_property
1212
from finat.element_factory import create_element
1313
from tsfc import compile_expression_dual_evaluation
1414
from pyop2 import op2
@@ -1220,44 +1220,57 @@ def work_function(self, V):
12201220

12211221
@cached_property
12221222
def _weight(self):
1223+
cell_set = self.Vf.mesh().topology.unique().cell_set
12231224
weight = firedrake.Function(self.Vf)
1224-
size = self.Vf.finat_element.space_dimension() * self.Vf.block_size
1225+
wsize = self.Vf.finat_element.space_dimension() * self.Vf.block_size
12251226
kernel_code = f"""
1226-
void weight(PetscScalar *restrict w){{
1227-
for(PetscInt i=0; i<{size}; i++) w[i] += 1.0;
1228-
return;
1229-
}}
1230-
"""
1231-
kernel = op2.Kernel(kernel_code, "weight", requires_zeroed_output_arguments=True)
1232-
op2.par_loop(kernel, weight.function_space().mesh().topology.unique().cell_set, weight.dat(op2.INC, weight.cell_node_map()))
1227+
void multiplicity(PetscScalar *restrict w) {{
1228+
for (PetscInt i=0; i<{wsize}; i++) w[i] += 1;
1229+
}}"""
1230+
kernel = op2.Kernel(kernel_code, "multiplicity")
1231+
op2.par_loop(kernel, cell_set, weight.dat(op2.INC, weight.cell_node_map()))
12331232
with weight.dat.vec as w:
12341233
w.reciprocal()
12351234
return weight
12361235

12371236
@cached_property
12381237
def _kernels(self):
12391238
try:
1240-
# We generate custom prolongation and restriction kernels mainly because:
1241-
# 1. Code generation for the transpose of prolongation is not readily available
1242-
# 2. Dual evaluation of EnrichedElement is not yet implemented in FInAT
1243-
uf_map = get_permuted_map(self.Vf)
1244-
uc_map = get_permuted_map(self.Vc)
1245-
prolong_kernel, restrict_kernel, coefficients = self.make_blas_kernels(self.Vf, self.Vc)
1246-
prolong_args = [prolong_kernel, self.uf.function_space().mesh().topology.unique().cell_set,
1247-
self.uf.dat(op2.INC, uf_map),
1248-
self.uc.dat(op2.READ, uc_map),
1249-
self._weight.dat(op2.READ, uf_map)]
1250-
except ValueError:
1251-
# The elements do not have the expected tensor product structure
1252-
# Fall back to aij kernels
1253-
uf_map = self.Vf.cell_node_map()
1254-
uc_map = self.Vc.cell_node_map()
1255-
prolong_kernel, restrict_kernel, coefficients = self.make_kernels(self.Vf, self.Vc)
1256-
prolong_args = [prolong_kernel, self.uf.function_space().mesh().topology.unique().cell_set,
1257-
self.uf.dat(op2.WRITE, uf_map),
1258-
self.uc.dat(op2.READ, uc_map)]
1259-
1260-
restrict_args = [restrict_kernel, self.uf.function_space().mesh().topology.unique().cell_set,
1239+
self.Vf.finat_element.dual_basis
1240+
self.Vc.finat_element.dual_basis
1241+
native_interpolation_supported = True
1242+
except NotImplementedError:
1243+
native_interpolation_supported = False
1244+
1245+
if native_interpolation_supported:
1246+
return self._build_native_interpolators()
1247+
else:
1248+
return self._build_custom_interpolators()
1249+
1250+
def _build_native_interpolators(self):
1251+
from firedrake.interpolation import interpolate, Interpolator
1252+
P = Interpolator(interpolate(self.uc, self.Vf), self.Vf)
1253+
prolong = partial(P.assemble, tensor=self.uf)
1254+
1255+
rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat)
1256+
rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat)
1257+
vc = firedrake.TestFunction(self.Vc)
1258+
R = Interpolator(interpolate(vc, rf), self.Vf)
1259+
restrict = partial(R.assemble, tensor=rc)
1260+
return prolong, restrict
1261+
1262+
def _build_custom_interpolators(self):
1263+
# We generate custom prolongation and restriction kernels because
1264+
# dual evaluation of EnrichedElement is not yet implemented in FInAT
1265+
uf_map = get_permuted_map(self.Vf)
1266+
uc_map = get_permuted_map(self.Vc)
1267+
prolong_kernel, restrict_kernel, coefficients = self.make_blas_kernels(self.Vf, self.Vc)
1268+
cell_set = self.Vf.mesh().topology.unique().cell_set
1269+
prolong_args = [prolong_kernel, cell_set,
1270+
self.uf.dat(op2.INC, uf_map),
1271+
self.uc.dat(op2.READ, uc_map),
1272+
self._weight.dat(op2.READ, uf_map)]
1273+
restrict_args = [restrict_kernel, cell_set,
12611274
self.uc.dat(op2.INC, uc_map),
12621275
self.uf.dat(op2.READ, uf_map),
12631276
self._weight.dat(op2.READ, uf_map)]
@@ -1444,49 +1457,6 @@ def make_blas_kernels(self, Vf, Vc):
14441457
ldargs=BLASLAPACK_LIB.split(), requires_zeroed_output_arguments=True)
14451458
return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients))
14461459

1447-
def make_kernels(self, Vf, Vc):
1448-
"""
1449-
Interpolation and restriction kernels between arbitrary elements.
1450-
1451-
This is temporary while we wait for dual evaluation in FInAT.
1452-
"""
1453-
cache = self._cache_kernels
1454-
key = (Vf.ufl_element(), Vc.ufl_element())
1455-
try:
1456-
return cache[key]
1457-
except KeyError:
1458-
pass
1459-
prolong_kernel, _ = prolongation_transfer_kernel_action(Vf, self.uc)
1460-
matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TrialFunction(Vc))
1461-
1462-
# The way we transpose the prolongation kernel is suboptimal.
1463-
# A local matrix is generated each time the kernel is executed.
1464-
element_kernel = cache_generate_code(matrix_kernel, Vf._comm)
1465-
element_kernel = element_kernel.replace("void expression_kernel", "static void expression_kernel")
1466-
coef_args = "".join([", c%d" % i for i in range(len(coefficients))])
1467-
coef_decl = "".join([", const %s *restrict c%d" % (ScalarType_c, i) for i in range(len(coefficients))])
1468-
dimc = Vc.finat_element.space_dimension() * Vc.block_size
1469-
dimf = Vf.finat_element.space_dimension() * Vf.block_size
1470-
restrict_code = f"""
1471-
{element_kernel}
1472-
1473-
void restriction({ScalarType_c} *restrict Rc, const {ScalarType_c} *restrict Rf, const {ScalarType_c} *restrict w{coef_decl})
1474-
{{
1475-
{ScalarType_c} Afc[{dimf}*{dimc}] = {{0}};
1476-
expression_kernel(Afc{coef_args});
1477-
for ({IntType_c} i = 0; i < {dimf}; i++)
1478-
for ({IntType_c} j = 0; j < {dimc}; j++)
1479-
Rc[j] += Afc[i*{dimc} + j] * Rf[i] * w[i];
1480-
}}
1481-
"""
1482-
restrict_kernel = op2.Kernel(
1483-
restrict_code,
1484-
"restriction",
1485-
requires_zeroed_output_arguments=True,
1486-
events=matrix_kernel.events,
1487-
)
1488-
return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients))
1489-
14901460
def multTranspose(self, mat, rf, rc):
14911461
"""
14921462
Implement restriction: restrict residual on fine grid rf to coarse grid rc.
@@ -1566,61 +1536,15 @@ def getNestSubMatrix(self, i, j):
15661536
return None
15671537

15681538

1569-
def prolongation_matrix_aij(P1, Pk, P1_bcs=[], Pk_bcs=[]):
1570-
if isinstance(P1, firedrake.Function):
1571-
P1 = P1.function_space()
1572-
if isinstance(Pk, firedrake.Function):
1573-
Pk = Pk.function_space()
1574-
sp = op2.Sparsity((Pk.dof_dset,
1575-
P1.dof_dset),
1576-
{(i, j): [(rmap, cmap, None)]
1577-
for i, rmap in enumerate(Pk.cell_node_map())
1578-
for j, cmap in enumerate(P1.cell_node_map())
1579-
if i == j})
1580-
mat = op2.Mat(sp, PETSc.ScalarType)
1581-
mesh = Pk.mesh()
1582-
1583-
fele = Pk.ufl_element()
1584-
if type(fele) is finat.ufl.MixedElement:
1585-
for i in range(fele.num_sub_elements):
1586-
Pk_bcs_i = [bc for bc in Pk_bcs if bc.function_space().index == i]
1587-
P1_bcs_i = [bc for bc in P1_bcs if bc.function_space().index == i]
1588-
1589-
rlgmap, clgmap = mat[i, i].local_to_global_maps
1590-
rlgmap = Pk.sub(i).local_to_global_map(Pk_bcs_i, lgmap=rlgmap)
1591-
clgmap = P1.sub(i).local_to_global_map(P1_bcs_i, lgmap=clgmap)
1592-
unroll = any(bc.function_space().component is not None
1593-
for bc in chain(Pk_bcs_i, P1_bcs_i) if bc is not None)
1594-
matarg = mat[i, i](op2.WRITE, (Pk.sub(i).cell_node_map(), P1.sub(i).cell_node_map()),
1595-
lgmaps=((rlgmap, clgmap), ), unroll_map=unroll)
1596-
expr = firedrake.TrialFunction(P1.sub(i))
1597-
kernel, coefficients = prolongation_transfer_kernel_action(Pk.sub(i), expr)
1598-
parloop_args = [kernel, mesh.topology.unique().cell_set, matarg]
1599-
for coefficient in coefficients:
1600-
m_ = coefficient.cell_node_map()
1601-
parloop_args.append(coefficient.dat(op2.READ, m_))
1602-
1603-
op2.par_loop(*parloop_args)
1604-
1605-
else:
1606-
rlgmap, clgmap = mat.local_to_global_maps
1607-
rlgmap = Pk.local_to_global_map(Pk_bcs, lgmap=rlgmap)
1608-
clgmap = P1.local_to_global_map(P1_bcs, lgmap=clgmap)
1609-
unroll = any(bc.function_space().component is not None
1610-
for bc in chain(Pk_bcs, P1_bcs) if bc is not None)
1611-
matarg = mat(op2.WRITE, (Pk.cell_node_map(), P1.cell_node_map()),
1612-
lgmaps=((rlgmap, clgmap), ), unroll_map=unroll)
1613-
expr = firedrake.TrialFunction(P1)
1614-
kernel, coefficients = prolongation_transfer_kernel_action(Pk, expr)
1615-
parloop_args = [kernel, mesh.topology.unique().cell_set, matarg]
1616-
for coefficient in coefficients:
1617-
m_ = coefficient.cell_node_map()
1618-
parloop_args.append(coefficient.dat(op2.READ, m_))
1619-
1620-
op2.par_loop(*parloop_args)
1621-
1622-
mat.assemble()
1623-
return mat.handle
1539+
def prolongation_matrix_aij(Vc, Vf, Vc_bcs=(), Vf_bcs=()):
1540+
if isinstance(Vf, firedrake.Function):
1541+
Vf = Vf.function_space()
1542+
if isinstance(Vc, firedrake.Function):
1543+
Vc = Vc.function_space()
1544+
bcs = Vc_bcs + Vf_bcs
1545+
interp = firedrake.interpolate(firedrake.TrialFunction(Vc), Vf)
1546+
mat = firedrake.assemble(interp, bcs=bcs)
1547+
return mat.petscmat
16241548

16251549

16261550
def prolongation_matrix_matfree(Vc, Vf, Vc_bcs=[], Vf_bcs=[]):

tests/firedrake/multigrid/test_p_multigrid.py

Lines changed: 8 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -177,8 +177,7 @@ def test_p_multigrid_scalar(mesh, mat_type, restrict):
177177
F = inner(grad(u), grad(v))*dx - inner(f, v)*dx
178178

179179
relax = {"ksp_type": "chebyshev",
180-
"ksp_monitor_true_residual": None,
181-
"ksp_norm_type": "unpreconditioned",
180+
"ksp_convergence_test": "skip",
182181
"ksp_max_it": 3,
183182
"pc_type": "jacobi"}
184183

@@ -188,20 +187,12 @@ def test_p_multigrid_scalar(mesh, mat_type, restrict):
188187
"ksp_monitor_true_residual": None,
189188
"pc_type": "python",
190189
"pc_python_type": "firedrake.PMGPC",
191-
"pmg_pc_mg_type": "multiplicative",
192190
"pmg_mg_levels": relax,
193191
"pmg_mg_levels_transfer_mat_type": mat_type,
194-
"pmg_mg_coarse_ksp_type": "richardson",
195-
"pmg_mg_coarse_ksp_max_it": 1,
196-
"pmg_mg_coarse_ksp_norm_type": "unpreconditioned",
197-
"pmg_mg_coarse_ksp_monitor": None,
192+
"pmg_mg_coarse_ksp_type": "preonly",
198193
"pmg_mg_coarse_pc_type": "mg",
199-
"pmg_mg_coarse_pc_mg_type": "multiplicative",
200194
"pmg_mg_coarse_mg_levels": relax,
201-
"pmg_mg_coarse_mg_coarse_ksp_type": "richardson",
202-
"pmg_mg_coarse_mg_coarse_ksp_max_it": 1,
203-
"pmg_mg_coarse_mg_coarse_ksp_norm_type": "unpreconditioned",
204-
"pmg_mg_coarse_mg_coarse_ksp_monitor": None,
195+
"pmg_mg_coarse_mg_coarse_ksp_type": "preonly",
205196
"pmg_mg_coarse_mg_coarse_pc_type": "gamg",
206197
"pmg_mg_coarse_mg_coarse_pc_gamg_threshold": 0}
207198
problem = NonlinearVariationalProblem(F, u, bcs, restrict=restrict)
@@ -225,8 +216,6 @@ def test_p_multigrid_nonlinear_scalar(mesh, mat_type):
225216
F = inner((Constant(1.0) + u**2) * grad(u), grad(v))*dx - inner(f, v)*dx
226217

227218
relax = {"ksp_type": "chebyshev",
228-
"ksp_monitor_true_residual": None,
229-
"ksp_norm_type": "unpreconditioned",
230219
"ksp_max_it": 3,
231220
"pc_type": "jacobi"}
232221

@@ -236,20 +225,12 @@ def test_p_multigrid_nonlinear_scalar(mesh, mat_type):
236225
"ksp_monitor_true_residual": None,
237226
"pc_type": "python",
238227
"pc_python_type": "firedrake.PMGPC",
239-
"pmg_pc_mg_type": "multiplicative",
240228
"pmg_mg_levels": relax,
241229
"pmg_mg_levels_transfer_mat_type": mat_type,
242-
"pmg_mg_coarse_ksp_type": "richardson",
243-
"pmg_mg_coarse_ksp_max_it": 1,
244-
"pmg_mg_coarse_ksp_norm_type": "unpreconditioned",
245-
"pmg_mg_coarse_ksp_monitor": None,
230+
"pmg_mg_coarse_ksp_type": "preonly",
246231
"pmg_mg_coarse_pc_type": "mg",
247-
"pmg_mg_coarse_pc_mg_type": "multiplicative",
248232
"pmg_mg_coarse_mg_levels": relax,
249-
"pmg_mg_coarse_mg_coarse_ksp_type": "richardson",
250-
"pmg_mg_coarse_mg_coarse_ksp_max_it": 1,
251-
"pmg_mg_coarse_mg_coarse_ksp_norm_type": "unpreconditioned",
252-
"pmg_mg_coarse_mg_coarse_ksp_monitor": None,
233+
"pmg_mg_coarse_mg_coarse_ksp_type": "preonly",
253234
"pmg_mg_coarse_mg_coarse_pc_type": "gamg",
254235
"pmg_mg_coarse_mg_coarse_pc_gamg_threshold": 0}
255236
problem = NonlinearVariationalProblem(F, u, bcs)
@@ -295,14 +276,9 @@ def test_p_multigrid_vector():
295276
"pc_python_type": "firedrake.PMGPC",
296277
"pmg_pc_mg_type": "full",
297278
"pmg_mg_levels_ksp_type": "chebyshev",
298-
"pmg_mg_levels_ksp_monitor_true_residual": None,
299-
"pmg_mg_levels_ksp_norm_type": "unpreconditioned",
300279
"pmg_mg_levels_ksp_max_it": 2,
301280
"pmg_mg_levels_pc_type": "pbjacobi",
302-
"pmg_mg_coarse_ksp_type": "richardson",
303-
"pmg_mg_coarse_ksp_max_it": 1,
304-
"pmg_mg_coarse_ksp_norm_type": "unpreconditioned",
305-
"pmg_mg_coarse_ksp_monitor": None,
281+
"pmg_mg_coarse_ksp_type": "preonly",
306282
"pmg_mg_coarse_pc_type": "lu"}
307283
problem = NonlinearVariationalProblem(F, u, bcs)
308284
solver = NonlinearVariationalSolver(problem, solver_parameters=sp)
@@ -328,16 +304,12 @@ def test_p_multigrid_mixed(mat_type):
328304

329305
relax = {"transfer_mat_type": mat_type,
330306
"ksp_type": "chebyshev",
331-
"ksp_monitor_true_residual": None,
332-
"ksp_norm_type": "unpreconditioned",
307+
"ksp_convergence_test": "skip",
333308
"ksp_max_it": 3,
334309
"pc_type": "jacobi"}
335310

336311
coarse = {"mat_type": "aij", # This circumvents the need for AssembledPC
337-
"ksp_type": "richardson",
338-
"ksp_max_it": 1,
339-
"ksp_norm_type": "unpreconditioned",
340-
"ksp_monitor": None,
312+
"ksp_type": "preonly",
341313
"pc_type": "cholesky",
342314
"pc_factor_shift_type": "nonzero",
343315
"pc_factor_shift_amount": 1E-10}
@@ -350,7 +322,6 @@ def test_p_multigrid_mixed(mat_type):
350322
"pc_type": "python",
351323
"pc_python_type": "firedrake.PMGPC",
352324
"mat_type": mat_type,
353-
"pmg_pc_mg_type": "multiplicative",
354325
"pmg_mg_levels": relax,
355326
"pmg_mg_coarse": coarse}
356327

@@ -424,13 +395,10 @@ def test_p_fas_scalar():
424395
coarse = {
425396
"mat_type": "aij",
426397
"ksp_type": "preonly",
427-
"ksp_norm_type": None,
428398
"pc_type": "cholesky"}
429399

430400
relax = {
431401
"ksp_type": "chebyshev",
432-
"ksp_monitor_true_residual": None,
433-
"ksp_norm_type": "unpreconditioned",
434402
"pc_type": "jacobi"}
435403

436404
pmg = {
@@ -512,12 +480,10 @@ def test_p_fas_nonlinear_scalar():
512480

513481
coarse = {
514482
"ksp_type": "preonly",
515-
"ksp_norm_type": None,
516483
"pc_type": "cholesky"}
517484

518485
relax = {
519486
"ksp_type": "chebyshev",
520-
"ksp_norm_type": "unpreconditioned",
521487
"ksp_chebyshev_esteig": "0.75,0.25,0,1",
522488
"ksp_max_it": 3,
523489
"pc_type": "jacobi"}
@@ -531,7 +497,6 @@ def test_p_fas_nonlinear_scalar():
531497
"ksp_norm_type": "unpreconditioned",
532498
"pc_type": "python",
533499
"pc_python_type": "firedrake.PMGPC",
534-
"pmg_pc_mg_type": "multiplicative",
535500
"pmg_mg_levels": relax,
536501
"pmg_mg_levels_transfer_mat_type": mat_type,
537502
"pmg_mg_coarse": coarse}

0 commit comments

Comments
 (0)