Skip to content

Commit 22c1a81

Browse files
committed
Fix MG caches
1 parent e21254e commit 22c1a81

File tree

4 files changed

+95
-10
lines changed

4 files changed

+95
-10
lines changed

firedrake/mg/embedded.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from enum import IntEnum
66
from firedrake.petsc import PETSc
77
from firedrake.embedding import get_embedding_dg_element
8+
from .utils import get_level
89

910

1011
__all__ = ("TransferManager", )
@@ -65,7 +66,6 @@ def __init__(self, *, native_transfers=None, use_averaging=True):
6566
"""
6667
self.native_transfers = native_transfers or {}
6768
self.use_averaging = use_averaging
68-
self.caches = {}
6969

7070
def is_native(self, element, op):
7171
if element in self.native_transfers.keys():
@@ -87,14 +87,17 @@ def _native_transfer(self, element, op):
8787
return None
8888

8989
def cache(self, V):
90+
mh, _ = get_level(V.mesh())
91+
caches = mh._shared_data_cache["transfer_manager_cache"]
9092
key = (V.ufl_element(), V.value_shape, V.boundary_set)
9193
try:
92-
return self.caches[key]
94+
return caches[key]
9395
except KeyError:
94-
return self.caches.setdefault(key, TransferManager.Cache(*key[:2]))
96+
return caches.setdefault(key, TransferManager.Cache(*key[:2]))
9597

9698
def cache_key(self, V):
97-
return (V.dim(),)
99+
_, level = get_level(V.mesh())
100+
return (level,)
98101

99102
def V_dof_weights(self, V):
100103
"""Dof weights for averaging projection.
@@ -143,7 +146,7 @@ def DG_inv_mass(self, DG):
143146
:returns: A PETSc Mat.
144147
"""
145148
cache = self.cache(DG)
146-
key = DG.dim()
149+
key = self.cache_key(DG)
147150
try:
148151
return cache._DG_inv_mass[key]
149152
except KeyError:

firedrake/preconditioners/pmg.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,6 @@ class StandaloneInterpolationMatrix(object):
11881188
"""
11891189

11901190
_cache_kernels = {}
1191-
_cache_work = {}
11921191

11931192
def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
11941193
self.uc = self.work_function(Vc)
@@ -1209,14 +1208,16 @@ def __init__(self, Vc, Vf, Vc_bcs, Vf_bcs):
12091208
self.Vc_bcs = [bc.reconstruct(V=self.Vc, g=0) for bc in self.Vc_bcs]
12101209
self.Vf_bcs = [bc.reconstruct(V=self.Vf, g=0) for bc in self.Vf_bcs]
12111210

1212-
def work_function(self, V):
1211+
@staticmethod
1212+
def work_function(V):
12131213
if isinstance(V, firedrake.Function):
12141214
return V
1215-
key = (V.ufl_element(), V.mesh(), V.boundary_set)
1215+
cache = V.mesh()._shared_data_cache["pmg_work_function"]
1216+
key = (V.ufl_element(), V.value_shape, V.boundary_set)
12161217
try:
1217-
return self._cache_work[key]
1218+
return cache[key]
12181219
except KeyError:
1219-
return self._cache_work.setdefault(key, firedrake.Function(V))
1220+
return cache.setdefault(key, firedrake.Function(V))
12201221

12211222
@cached_property
12221223
def _weight(self):

tests/firedrake/multigrid/test_p_multigrid.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def test_reconstruct_degree(tp_mesh, mixed_family):
7373
assert e == PMGPC.reconstruct_degree(elist[0], degree)
7474

7575

76+
def test_work_function_cache(tp_mesh):
77+
from firedrake.preconditioners.pmg import StandaloneInterpolationMatrix
78+
79+
V1 = FunctionSpace(tp_mesh, "Lagrange", 1)
80+
V2 = FunctionSpace(tp_mesh, "Lagrange", 1)
81+
assert V1 is not V2
82+
w1 = StandaloneInterpolationMatrix.work_function(V1)
83+
w2 = StandaloneInterpolationMatrix.work_function(V2)
84+
assert w1 is w2
85+
86+
7687
@pytest.mark.parametrize("family", ["Q", "NCE", "NCF", "DQ"])
7788
def test_prolong_basic(tp_mesh, family):
7889
""" Interpolate a constant function between low-order and high-order spaces

tests/firedrake/multigrid/test_transfer_manager.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,73 @@ def test_transfer_manager_dat_version_cache(action, transfer_op, spaces):
131131

132132
else:
133133
raise ValueError(f"Unrecognized action {action}")
134+
135+
136+
@pytest.fixture
137+
def DG_spaces(hierarchy):
138+
return tuple(VectorFunctionSpace(mesh, "DG", 1) for mesh in hierarchy)
139+
140+
141+
@pytest.fixture
142+
def RT_spaces(hierarchy):
143+
return tuple(RestrictedFunctionSpace(FunctionSpace(mesh, "RT", 1), ["on_boundary"]) for mesh in hierarchy)
144+
145+
146+
@pytest.mark.parametrize("action", [
147+
"DG_work_function",
148+
"DG_work_cofunction",
149+
"V_dof_weights",
150+
"work_vec",
151+
"V_DG_mass",
152+
"DG_inv_mass",
153+
"V_approx_inv_mass",
154+
"V_inv_mass_ksp",
155+
])
156+
def test_transfer_manager_cache(action, DG_spaces, RT_spaces):
157+
V1 = RT_spaces[0]
158+
DG1 = DG_spaces[0]
159+
160+
V2 = V1.reconstruct(name="V2")
161+
DG2 = DG1.reconstruct(name="DG2")
162+
assert V2 is not V1
163+
assert DG2 is not DG1
164+
165+
if action == "DG_work_function":
166+
w1 = transfer.DG_work(DG1)
167+
w2 = transfer.DG_work(DG2)
168+
assert w1 is w2
169+
170+
elif action == "DG_work_cofunction":
171+
w1 = transfer.DG_work(DG1.dual())
172+
w2 = transfer.DG_work(DG2.dual())
173+
assert w1 is w2
174+
175+
elif action == "V_dof_weights":
176+
w1 = transfer.V_dof_weights(V1)
177+
w2 = transfer.V_dof_weights(V2)
178+
assert w1 is w2
179+
180+
elif action == "work_vec":
181+
w1 = transfer.work_vec(V1)
182+
w2 = transfer.work_vec(V2)
183+
assert w1 is w2
184+
185+
elif action == "V_DG_mass":
186+
M1 = transfer.V_DG_mass(V1, DG1)
187+
M2 = transfer.V_DG_mass(V2, DG2)
188+
assert M1 is M2
189+
190+
elif action == "DG_inv_mass":
191+
M1 = transfer.DG_inv_mass(DG1)
192+
M2 = transfer.DG_inv_mass(DG2)
193+
assert M1 is M2
194+
195+
elif action == "V_approx_inv_mass":
196+
M1 = transfer.V_approx_inv_mass(V1, DG1)
197+
M2 = transfer.V_approx_inv_mass(V2, DG2)
198+
assert M1 is M2
199+
200+
elif action == "V_inv_mass_ksp":
201+
ksp1 = transfer.V_inv_mass_ksp(V1)
202+
ksp2 = transfer.V_inv_mass_ksp(V2)
203+
assert ksp1 is ksp2

0 commit comments

Comments
 (0)