From 08dd65b78776ca3f2d3bbc20839149ad00ae6d5d Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Fri, 7 Feb 2025 14:10:41 -0800 Subject: [PATCH] Ensure AsyncCheckpointer completion logs on every host instead of just the leader. PiperOrigin-RevId: 724466022 --- checkpoint/CHANGELOG.md | 7 +++++++ .../_src/checkpointers/async_checkpointer.py | 20 +++++++++++-------- checkpoint/orbax/checkpoint/version.py | 2 +- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 13de6600..bfae0ca6 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,10 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.11.4] - 2025-02-07 + ### Changed - Updated orbax-checkpoint PyPI package to exclude tests. +### Fixed + +- `AsyncCheckpointer` completion logging, to log on all hosts instead of just +the leader. + ## [0.11.3] - 2025-02-06 ### Changed diff --git a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py index 58b01302..d9d4f78c 100644 --- a/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py @@ -58,11 +58,6 @@ def _on_commit_callback( '/jax/checkpoint/write/async/total_duration_secs', total_duration_secs, ) - logging.info( - 'Finished asynchronous save in %.2f seconds to %s', - total_duration_secs, - tmpdir.get_final(), - ) def _add_deadline_exceeded_notes(e: jax.errors.JaxRuntimeError): @@ -405,7 +400,8 @@ def _callback() -> None: # Update StepMetadata after the handler save is complete. # (blocking write) self._save_step_metadata(tmpdir.get(), custom_metadata=custom_metadata) - logging.info( + logging.vlog( + 1, '[process=%s][thread=%s] Async Save Callback [1/3]: Finalizing' ' Handler: %s on %s', multihost.process_index(), @@ -415,7 +411,8 @@ def _callback() -> None: ) # Finalize does a final StepMetadata update. self._handler.finalize(tmpdir.get()) - logging.info( + logging.vlog( + 1, '[process=%s][thread=%s] Async Save Callback [2/3]: Running' ' post_finalization_callback: %s on %s', multihost.process_index(), @@ -425,7 +422,8 @@ def _callback() -> None: ) if self._post_finalization_callback is not None: self._post_finalization_callback() - logging.info( + logging.vlog( + 1, '[process=%s][thread=%s] Async Save Callback [3/3]: Finalizing' ' checkpoint directory: %s', multihost.process_index(), @@ -436,6 +434,12 @@ def _callback() -> None: tmpdir, checkpoint_start_time, ) + logging.info( + 'Finished asynchronous save (blocking + background) in %.2f seconds' + ' to %s', + time.time() - checkpoint_start_time, + directory, + ) self._async_manager.start_async_commit( directory, diff --git a/checkpoint/orbax/checkpoint/version.py b/checkpoint/orbax/checkpoint/version.py index 5c127a9f..ba6a7123 100644 --- a/checkpoint/orbax/checkpoint/version.py +++ b/checkpoint/orbax/checkpoint/version.py @@ -16,7 +16,7 @@ # A new PyPI release will be pushed everytime `__version__` is increased. # Also modify version and date in CHANGELOG. -__version__ = '0.11.3' +__version__ = '0.11.4' # TODO: b/362813406 - Add latest change timestamp and commit number.