|
8 | 8 | from firedrake.nullspace import VectorSpaceBasis, MixedVectorSpaceBasis |
9 | 9 | from firedrake.solving_utils import _SNESContext |
10 | 10 | from firedrake.tsfc_interface import extract_numbered_coefficients |
11 | | -from firedrake.utils import ScalarType_c, IntType_c, cached_property |
| 11 | +from firedrake.utils import IntType_c, cached_property |
12 | 12 | from finat.element_factory import create_element |
13 | 13 | from tsfc import compile_expression_dual_evaluation |
14 | 14 | from pyop2 import op2 |
@@ -1236,13 +1236,18 @@ def _weight(self): |
1236 | 1236 |
|
1237 | 1237 | @cached_property |
1238 | 1238 | def _kernels(self): |
| 1239 | + from firedrake.interpolation import interpolate, Interpolator |
1239 | 1240 | try: |
1240 | | - prolong = partial(firedrake.assemble, firedrake.interpolate(self.uc, self.Vf), tensor=self.uf) |
1241 | | - prolong() |
1242 | | - self.rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat) |
1243 | | - self.rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat) |
1244 | | - restrict = partial(firedrake.assemble, firedrake.interpolate(firedrake.TestFunction(self.Vc), self.rf), tensor=self.rc) |
1245 | | - except NotImplementedError: |
| 1241 | + assert self.Vf.ufl_element().mapping() == self.Vc.ufl_element().mapping() |
| 1242 | + P = Interpolator(interpolate(self.uc, self.Vf), self.Vf) |
| 1243 | + prolong = partial(P.assemble, tensor=self.uf) |
| 1244 | + |
| 1245 | + rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat) |
| 1246 | + rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat) |
| 1247 | + vc = firedrake.TestFunction(self.Vc) |
| 1248 | + R = Interpolator(interpolate(vc, rf), self.Vf) |
| 1249 | + restrict = partial(R.assemble, tensor=rc) |
| 1250 | + except (AttributeError, AssertionError, NotImplementedError): |
1246 | 1251 | # We generate custom prolongation and restriction kernels because |
1247 | 1252 | # dual evaluation of EnrichedElement is not yet implemented in FInAT |
1248 | 1253 | uf_map = get_permuted_map(self.Vf) |
@@ -1439,49 +1444,6 @@ def make_blas_kernels(self, Vf, Vc): |
1439 | 1444 | ldargs=BLASLAPACK_LIB.split(), requires_zeroed_output_arguments=True) |
1440 | 1445 | return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients)) |
1441 | 1446 |
|
1442 | | - def make_kernels(self, Vf, Vc): |
1443 | | - """ |
1444 | | - Interpolation and restriction kernels between arbitrary elements. |
1445 | | -
|
1446 | | - This is temporary while we wait for dual evaluation in FInAT. |
1447 | | - """ |
1448 | | - cache = self._cache_kernels |
1449 | | - key = (Vf.ufl_element(), Vc.ufl_element()) |
1450 | | - try: |
1451 | | - return cache[key] |
1452 | | - except KeyError: |
1453 | | - pass |
1454 | | - prolong_kernel, _ = prolongation_transfer_kernel_action(Vf, self.uc) |
1455 | | - matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TrialFunction(Vc)) |
1456 | | - |
1457 | | - # The way we transpose the prolongation kernel is suboptimal. |
1458 | | - # A local matrix is generated each time the kernel is executed. |
1459 | | - element_kernel = cache_generate_code(matrix_kernel, Vf._comm) |
1460 | | - element_kernel = element_kernel.replace("void expression_kernel", "static void expression_kernel") |
1461 | | - coef_args = "".join([", c%d" % i for i in range(len(coefficients))]) |
1462 | | - coef_decl = "".join([", const %s *restrict c%d" % (ScalarType_c, i) for i in range(len(coefficients))]) |
1463 | | - dimc = Vc.finat_element.space_dimension() * Vc.block_size |
1464 | | - dimf = Vf.finat_element.space_dimension() * Vf.block_size |
1465 | | - restrict_code = f""" |
1466 | | - {element_kernel} |
1467 | | -
|
1468 | | - void restriction({ScalarType_c} *restrict Rc, const {ScalarType_c} *restrict Rf, const {ScalarType_c} *restrict w{coef_decl}) |
1469 | | - {{ |
1470 | | - {ScalarType_c} Afc[{dimf}*{dimc}] = {{0}}; |
1471 | | - expression_kernel(Afc{coef_args}); |
1472 | | - for ({IntType_c} i = 0; i < {dimf}; i++) |
1473 | | - for ({IntType_c} j = 0; j < {dimc}; j++) |
1474 | | - Rc[j] += Afc[i*{dimc} + j] * Rf[i] * w[i]; |
1475 | | - }} |
1476 | | - """ |
1477 | | - restrict_kernel = op2.Kernel( |
1478 | | - restrict_code, |
1479 | | - "restriction", |
1480 | | - requires_zeroed_output_arguments=True, |
1481 | | - events=matrix_kernel.events, |
1482 | | - ) |
1483 | | - return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients)) |
1484 | | - |
1485 | 1447 | def multTranspose(self, mat, rf, rc): |
1486 | 1448 | """ |
1487 | 1449 | Implement restriction: restrict residual on fine grid rf to coarse grid rc. |
|
0 commit comments