Skip to content

Commit 2e52374

Browse files
committed
Fix multigrid caches
1 parent e4e6817 commit 2e52374

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

firedrake/mg/embedded.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from enum import IntEnum
66
from firedrake.petsc import PETSc
77
from firedrake.embedding import get_embedding_dg_element
8+
from .utils import get_level
89

910

1011
__all__ = ("TransferManager", )
@@ -40,7 +41,7 @@ class Cache(object):
4041
"""A caching object for work vectors and matrices.
4142
4243
:arg element: The element to use for the caching."""
43-
def __init__(self, ufl_element, value_shape):
44+
def __init__(self, ufl_element, value_shape, boundary_set):
4445
self.embedding_element = get_embedding_dg_element(ufl_element, value_shape)
4546
self._dat_versions = {}
4647
self._V_DG_mass = {}
@@ -65,7 +66,6 @@ def __init__(self, *, native_transfers=None, use_averaging=True):
6566
"""
6667
self.native_transfers = native_transfers or {}
6768
self.use_averaging = use_averaging
68-
self.caches = {}
6969

7070
def is_native(self, element, op):
7171
if element in self.native_transfers.keys():
@@ -87,14 +87,17 @@ def _native_transfer(self, element, op):
8787
return None
8888

8989
def cache(self, V):
90-
key = (V.ufl_element(), V.value_shape)
90+
mh, _ = get_level(V.mesh())
91+
caches = mh._shared_data_cache["transfer_manager_cache"]
92+
key = (V.ufl_element(), V.value_shape, V.boundary_set)
9193
try:
92-
return self.caches[key]
94+
return caches[key]
9395
except KeyError:
94-
return self.caches.setdefault(key, TransferManager.Cache(*key))
96+
return caches.setdefault(key, TransferManager.Cache(*key))
9597

9698
def cache_key(self, V):
97-
return (V.dim(), V.boundary_set)
99+
_, level = get_level(V.mesh())
100+
return level
98101

99102
def V_dof_weights(self, V):
100103
"""Dof weights for averaging projection.

firedrake/preconditioners/pmg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,6 @@ class StandaloneInterpolationMatrix(object):
11881188
"""
11891189

11901190
_cache_kernels = {}
1191-
_cache_work = {}
11921191

11931192
def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
11941193
self.uc = self.work_function(Vc)
@@ -1212,11 +1211,12 @@ def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
12121211
def work_function(self, V):
12131212
if isinstance(V, firedrake.Function):
12141213
return V
1215-
key = (V.ufl_element(), V.mesh(), V.boundary_set)
1214+
cache = V.mesh()._shared_data_cache["pmg_work_function"]
1215+
key = (V.ufl_element(), V.value_shape, V.boundary_set)
12161216
try:
1217-
return self._cache_work[key]
1217+
return cache[key]
12181218
except KeyError:
1219-
return self._cache_work.setdefault(key, firedrake.Function(V))
1219+
return cache.setdefault(key, firedrake.Function(V))
12201220

12211221
@cached_property
12221222
def _weight(self):

0 commit comments

Comments
 (0)