Skip to content

Commit

Permalink
Add save barrier to ensure checkpoint manager and checkpointers save start is synchronized on all processes.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 730721643
  • Loading branch information
mridul-sahu authored and Orbax Authors committed Feb 27, 2025
1 parent acec3f3 commit 4c7455f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def save(
*args,
force: bool = False,
custom_metadata: dict[str, Any] | None = None,
**kwargs
**kwargs,
):
"""Saves the given item to the provided directory.
Expand All @@ -498,6 +498,23 @@ def save(
ValueError if the provided directory already exists.
"""
checkpoint_start_time = time.time()
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:save_start',
prefix=self._barrier_sync_key_prefix,
),
processes=self._active_processes,
)
start_sync_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/start_sync_duration_secs',
start_sync_duration_secs,
)
logging.vlog(
1,
'Finished async checkpointer save start sync in %.2f seconds',
start_sync_duration_secs,
)
directory = epath.Path(directory)
tmpdir = self.get_temporary_path(directory)
on_commit_callback = self._make_on_commit_callback(
Expand Down
17 changes: 17 additions & 0 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ def save(
ValueError if the provided directory already exists.
"""
checkpoint_start_time = time.time()
multihost.sync_global_processes(
multihost.unique_barrier_key(
'Checkpointer:save_start',
prefix=self._barrier_sync_key_prefix,
),
processes=self._active_processes,
)
start_sync_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/start_sync_duration_secs',
start_sync_duration_secs,
)
logging.vlog(
1,
'Finished checkpointer save start sync in %.2f seconds',
start_sync_duration_secs,
)
directory = epath.Path(directory)

jax.monitoring.record_event('/jax/orbax/write/start')
Expand Down
19 changes: 19 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,25 @@ def save(
step_stats.checkpoint_manager_blocking_start_time = time.time()
step_stats.directory = str(self.directory)

checkpoint_start_time = time.time()
multihost.sync_global_processes(
multihost.unique_barrier_key(
'CheckpointManager:save_start',
prefix=self._multiprocessing_options.barrier_sync_key_prefix,
),
processes=self._multiprocessing_options.active_processes,
)
start_sync_duration_secs = time.time() - checkpoint_start_time
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/manager_start_sync_duration_secs',
start_sync_duration_secs,
)
logging.vlog(
1,
'Finished checkpoint manager save start sync in %.2f seconds',
start_sync_duration_secs,
)

if items is None and args is None:
raise ValueError('Must provide `args` for `save`.')
self._default_item.set_if_none(determine_default_item_mode_from_args(args))
Expand Down

0 comments on commit 4c7455f

Please sign in to comment.