@@ -533,7 +533,6 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
533533 self .commkey = self ._comm .py2f ()
534534 assert self .commkey != MPI .COMM_NULL .py2f ()
535535 self ._function_spaces = {}
536- self ._function_load_utils = {}
537536 if mode in [PETSc .Viewer .FileMode .WRITE , PETSc .Viewer .FileMode .W , "w" ]:
538537 version = CheckpointFile .latest_version
539538 self .set_attr_byte_string ("/" , "dmplex_storage_version" , version )
@@ -1058,11 +1057,8 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
10581057 cell = base_tmesh .ufl_cell ()
10591058 element = finat .ufl .VectorElement ("DP" if cell .is_simplex else "DQ" , cell , 0 , dim = 2 )
10601059 _ = self ._load_function_space_topology (base_tmesh , element )
1061- base_tmesh_key = self ._generate_mesh_key_from_names (base_tmesh .name ,
1062- base_tmesh ._distribution_name ,
1063- base_tmesh ._permutation_name )
10641060 sd_key = self ._get_shared_data_key_for_checkpointing (base_tmesh , element )
1065- _ , _ , lsf = self ._function_load_utils [ base_tmesh_key + sd_key ]
1061+ _ , _ , lsf = self ._shared_data_cache ( base_tmesh )[ sd_key ]
10661062 nroots , _ , _ = lsf .getGraph ()
10671063 layers_a = np .empty (nroots , dtype = utils .IntType )
10681064 layers_a_iset = PETSc .IS ().createGeneral (layers_a , comm = self ._comm )
@@ -1121,11 +1117,8 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
11211117 cell = tmesh .ufl_cell ()
11221118 element = finat .ufl .FiniteElement ("DP" if cell .is_simplex else "DQ" , cell , 0 )
11231119 cell_orientations_tV = self ._load_function_space_topology (tmesh , element )
1124- tmesh_key = self ._generate_mesh_key_from_names (tmesh .name ,
1125- tmesh ._distribution_name ,
1126- tmesh ._permutation_name )
11271120 sd_key = self ._get_shared_data_key_for_checkpointing (tmesh , element )
1128- _ , _ , lsf = self ._function_load_utils [ tmesh_key + sd_key ]
1121+ _ , _ , lsf = self ._shared_data_cache ( tmesh )[ sd_key ]
11291122 nroots , _ , _ = lsf .getGraph ()
11301123 cell_orientations_a = np .empty (nroots , dtype = utils .IntType )
11311124 cell_orientations_a_iset = PETSc .IS ().createGeneral (cell_orientations_a , comm = self ._comm )
@@ -1270,11 +1263,8 @@ def _load_function_space(self, mesh, name):
12701263 def _load_function_space_topology (self , tmesh , element ):
12711264 if element .family () == "Real" :
12721265 return impl .RealFunctionSpace (tmesh , element , "unused_name" )
1273- tmesh_key = self ._generate_mesh_key_from_names (tmesh .name ,
1274- tmesh ._distribution_name ,
1275- tmesh ._permutation_name )
12761266 sd_key = self ._get_shared_data_key_for_checkpointing (tmesh , element )
1277- if tmesh_key + sd_key not in self ._function_load_utils :
1267+ if sd_key not in self ._shared_data_cache ( tmesh ) :
12781268 topology_dm = tmesh .topology_dm
12791269 dm = PETSc .DMShell ().create (comm = tmesh ._comm )
12801270 dm .setName (self ._get_dm_name_for_checkpointing (tmesh , element ))
@@ -1293,7 +1283,7 @@ def _load_function_space_topology(self, tmesh, element):
12931283 if dm .getSection () is not cached_section :
12941284 # The same section has already been cached.
12951285 dm .setSection (cached_section )
1296- self ._function_load_utils [ tmesh_key + sd_key ] = (dm , gsf , lsf )
1286+ self ._shared_data_cache ( tmesh )[ sd_key ] = (dm , gsf , lsf )
12971287 return impl .FunctionSpace (tmesh , element )
12981288
12991289 @PETSc .Log .EventDecorator ("LoadFunction" )
@@ -1379,10 +1369,7 @@ def _load_function_topology(self, tmesh, element, tf_name, idx=None):
13791369 with tf .dat .vec_wo as vec :
13801370 vec .setName (tf_name )
13811371 sd_key = self ._get_shared_data_key_for_checkpointing (tmesh , element )
1382- tmesh_key = self ._generate_mesh_key_from_names (tmesh .name ,
1383- tmesh ._distribution_name ,
1384- tmesh ._permutation_name )
1385- dm , sf , _ = self ._function_load_utils [tmesh_key + sd_key ]
1372+ dm , sf , _ = self ._shared_data_cache (tmesh )[sd_key ]
13861373 base_tmesh_name = topology_dm .getName ()
13871374 topology_dm .setName (tmesh .name )
13881375 topology_dm .globalVectorLoad (self .viewer , dm , sf , vec )
@@ -1460,6 +1447,14 @@ def _get_dm_name_for_checkpointing(self, tmesh, ufl_element):
14601447 sd_key = self ._get_shared_data_key_for_checkpointing (tmesh , ufl_element )
14611448 return self ._generate_dm_name (* sd_key )
14621449
1450+ def _shared_data_cache (self , tmesh ):
1451+ # Cache gsf/lsf that push forward the on-disk DoF vector to the in-memory global/local vectors.
1452+ # Cache on mesh, not on self, so that they can be used across multiple CheckpointFile instances
1453+ # (at the cost of longer life).
1454+ if not hasattr (tmesh , "_shared_data_cache" ):
1455+ raise RuntimeError (f"_shared_data_cache not on { tmesh } " )
1456+ return tmesh ._shared_data_cache ["checkpointfile_" + self .filename ]
1457+
14631458 def _path_to_topologies (self ):
14641459 return "topologies"
14651460
0 commit comments