Skip to content

Commit 03758e0

Browse files
committed
io: cache on mesh
1 parent fdde737 commit 03758e0

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

firedrake/checkpointing.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,6 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
533533
self.commkey = self._comm.py2f()
534534
assert self.commkey != MPI.COMM_NULL.py2f()
535535
self._function_spaces = {}
536-
self._function_load_utils = {}
537536
if mode in [PETSc.Viewer.FileMode.WRITE, PETSc.Viewer.FileMode.W, "w"]:
538537
version = CheckpointFile.latest_version
539538
self.set_attr_byte_string("/", "dmplex_storage_version", version)
@@ -1058,11 +1057,8 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
10581057
cell = base_tmesh.ufl_cell()
10591058
element = finat.ufl.VectorElement("DP" if cell.is_simplex else "DQ", cell, 0, dim=2)
10601059
_ = self._load_function_space_topology(base_tmesh, element)
1061-
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
1062-
base_tmesh._distribution_name,
1063-
base_tmesh._permutation_name)
10641060
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
1065-
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
1061+
_, _, lsf = self._shared_data_cache(base_tmesh)[sd_key]
10661062
nroots, _, _ = lsf.getGraph()
10671063
layers_a = np.empty(nroots, dtype=utils.IntType)
10681064
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
@@ -1121,11 +1117,8 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
11211117
cell = tmesh.ufl_cell()
11221118
element = finat.ufl.FiniteElement("DP" if cell.is_simplex else "DQ", cell, 0)
11231119
cell_orientations_tV = self._load_function_space_topology(tmesh, element)
1124-
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
1125-
tmesh._distribution_name,
1126-
tmesh._permutation_name)
11271120
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
1128-
_, _, lsf = self._function_load_utils[tmesh_key + sd_key]
1121+
_, _, lsf = self._shared_data_cache(tmesh)[sd_key]
11291122
nroots, _, _ = lsf.getGraph()
11301123
cell_orientations_a = np.empty(nroots, dtype=utils.IntType)
11311124
cell_orientations_a_iset = PETSc.IS().createGeneral(cell_orientations_a, comm=self._comm)
@@ -1270,11 +1263,8 @@ def _load_function_space(self, mesh, name):
12701263
def _load_function_space_topology(self, tmesh, element):
12711264
if element.family() == "Real":
12721265
return impl.RealFunctionSpace(tmesh, element, "unused_name")
1273-
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
1274-
tmesh._distribution_name,
1275-
tmesh._permutation_name)
12761266
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
1277-
if tmesh_key + sd_key not in self._function_load_utils:
1267+
if sd_key not in self._shared_data_cache(tmesh):
12781268
topology_dm = tmesh.topology_dm
12791269
dm = PETSc.DMShell().create(comm=tmesh._comm)
12801270
dm.setName(self._get_dm_name_for_checkpointing(tmesh, element))
@@ -1293,7 +1283,7 @@ def _load_function_space_topology(self, tmesh, element):
12931283
if dm.getSection() is not cached_section:
12941284
# The same section has already been cached.
12951285
dm.setSection(cached_section)
1296-
self._function_load_utils[tmesh_key + sd_key] = (dm, gsf, lsf)
1286+
self._shared_data_cache(tmesh)[sd_key] = (dm, gsf, lsf)
12971287
return impl.FunctionSpace(tmesh, element)
12981288

12991289
@PETSc.Log.EventDecorator("LoadFunction")
@@ -1379,10 +1369,7 @@ def _load_function_topology(self, tmesh, element, tf_name, idx=None):
13791369
with tf.dat.vec_wo as vec:
13801370
vec.setName(tf_name)
13811371
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
1382-
tmesh_key = self._generate_mesh_key_from_names(tmesh.name,
1383-
tmesh._distribution_name,
1384-
tmesh._permutation_name)
1385-
dm, sf, _ = self._function_load_utils[tmesh_key + sd_key]
1372+
dm, sf, _ = self._shared_data_cache(tmesh)[sd_key]
13861373
base_tmesh_name = topology_dm.getName()
13871374
topology_dm.setName(tmesh.name)
13881375
topology_dm.globalVectorLoad(self.viewer, dm, sf, vec)
@@ -1460,6 +1447,14 @@ def _get_dm_name_for_checkpointing(self, tmesh, ufl_element):
14601447
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, ufl_element)
14611448
return self._generate_dm_name(*sd_key)
14621449

1450+
def _shared_data_cache(self, tmesh):
1451+
# Cache gsf/lsf that push forward the on-disk DoF vector to the in-memory global/local vectors.
1452+
# Cache on mesh, not on self, so that they can be used across multiple CheckpointFile instances
1453+
# (at the cost of longer life).
1454+
if not hasattr(tmesh, "_shared_data_cache"):
1455+
raise RuntimeError(f"_shared_data_cache not on {tmesh}")
1456+
return tmesh._shared_data_cache["checkpointfile_" + self.filename]
1457+
14631458
def _path_to_topologies(self):
14641459
return "topologies"
14651460

0 commit comments

Comments
 (0)