Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731878819
  • Loading branch information
niketkumar authored and Orbax Authors committed Feb 27, 2025
1 parent 4a24304 commit 950afa6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Fix RESOURCE_EXHAUSTED while writing array_metadatas.

### Changed

- Improve `Cannot serialize host local jax.Array` error message.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def __init__(
type_handler_registry
)
)
if self._array_metadata_store:
self._array_metadata_store.set_primary_host(self._primary_host)
self._array_metadata_validator = array_metadata_validator


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,45 @@ def __init__(
self,
path_resolver: PathResolver = PathResolver(),
ser_deser: SerDeserializer = SerDeserializer(),
primary_host: int | None = 0, # None means all hosts are primary hosts.
write_timeout_secs: int = 600, # 10 minutes.
):
self._path_resolver = path_resolver
self._ser_deser = ser_deser
self._primary_host = primary_host
self._write_timeout_secs = write_timeout_secs

def set_primary_host(self, primary_host: int | None) -> None:
"""Sets the primary host."""
self._primary_host = primary_host

async def _maybe_create_base_dir(self, base_dir: epath.Path) -> None:
"""Primary host creates the base directory, rest of the hosts wait."""
if multihost.is_primary_host(self._primary_host):
# primary host creates, rest of the hosts wait.
return await asyncio.to_thread(
base_dir.mkdir, parents=True, exist_ok=True
)

# non-primary host waits for primary host to create the base dir/folder.
async def wait_for_base_dir_creation():
while not await asyncio.to_thread(base_dir.exists):
await asyncio.sleep(0.25)

try:
await asyncio.wait_for(
wait_for_base_dir_creation(), timeout=self._write_timeout_secs
)
except asyncio.TimeoutError as e:
primary_process = (
'LOCAL' if self._primary_host is None else self._primary_host
)
raise ValueError(
f'[process_index={multihost.process_index()}] Timed out waiting for'
f' array_metadatas base directory creation: {base_dir}.'
f' timeout={self._write_timeout_secs} seconds.'
f' primary_process={primary_process}'
) from e

async def write(
self,
Expand All @@ -155,7 +191,7 @@ async def write(
file_path = self._path_resolver.get_write_file_path(
checkpoint_dir, process_index
)
await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
await self._maybe_create_base_dir(file_path.parent)
await asyncio.to_thread(
file_path.write_text, self._ser_deser.serialize(array_metadatas)
)
Expand Down

0 comments on commit 950afa6

Please sign in to comment.