Skip to content

Commit

Permalink
Ensure AsyncCheckpointer completion logs on every host instead of jus…
Browse files Browse the repository at this point in the history
…t the leader.

PiperOrigin-RevId: 724466022
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Feb 7, 2025
1 parent a1718d2 commit 08dd65b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
7 changes: 7 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.11.4] - 2025-02-07

### Changed

- Updated orbax-checkpoint PyPI package to exclude tests.

### Fixed

- `AsyncCheckpointer` completion logging, to log on all hosts instead of just
the leader.

## [0.11.3] - 2025-02-06

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def _on_commit_callback(
'/jax/checkpoint/write/async/total_duration_secs',
total_duration_secs,
)
logging.info(
'Finished asynchronous save in %.2f seconds to %s',
total_duration_secs,
tmpdir.get_final(),
)


def _add_deadline_exceeded_notes(e: jax.errors.JaxRuntimeError):
Expand Down Expand Up @@ -405,7 +400,8 @@ def _callback() -> None:
# Update StepMetadata after the handler save is complete.
# (blocking write)
self._save_step_metadata(tmpdir.get(), custom_metadata=custom_metadata)
logging.info(
logging.vlog(
1,
'[process=%s][thread=%s] Async Save Callback [1/3]: Finalizing'
' Handler: %s on %s',
multihost.process_index(),
Expand All @@ -415,7 +411,8 @@ def _callback() -> None:
)
# Finalize does a final StepMetadata update.
self._handler.finalize(tmpdir.get())
logging.info(
logging.vlog(
1,
'[process=%s][thread=%s] Async Save Callback [2/3]: Running'
' post_finalization_callback: %s on %s',
multihost.process_index(),
Expand All @@ -425,7 +422,8 @@ def _callback() -> None:
)
if self._post_finalization_callback is not None:
self._post_finalization_callback()
logging.info(
logging.vlog(
1,
'[process=%s][thread=%s] Async Save Callback [3/3]: Finalizing'
' checkpoint directory: %s',
multihost.process_index(),
Expand All @@ -436,6 +434,12 @@ def _callback() -> None:
tmpdir,
checkpoint_start_time,
)
logging.info(
'Finished asynchronous save (blocking + background) in %.2f seconds'
' to %s',
time.time() - checkpoint_start_time,
directory,
)

self._async_manager.start_async_commit(
directory,
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.11.3'
__version__ = '0.11.4'


# TODO: b/362813406 - Add latest change timestamp and commit number.
Expand Down

0 comments on commit 08dd65b

Please sign in to comment.