Skip to content

Commit

Permalink
Add ArrayStorageOptions to customize per-leaf saving behavior for a…
Browse files Browse the repository at this point in the history
…rrays (e.g. `dtype`).

PiperOrigin-RevId: 726245749
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Mar 1, 2025
1 parent 4a24304 commit 5ffbb6e
Show file tree
Hide file tree
Showing 21 changed files with 625 additions and 103 deletions.
13 changes: 12 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- V1: Introduce Context and its usage as a contextmanager.
- V1: Add `ArrayStorageOptions` to customize per-leaf saving behavior for
arrays (e.g. `dtype`).

### Fixed

- Fix RESOURCE_EXHAUSTED while writing array_metadatas.

### Changed

- Improve `Cannot serialize host local jax.Array` error message.

### Added

- Support saving and restoring `jax.random.key()` in PyTree.
- `CheckpointableHandler` for V1.

## [0.11.6] - 2025-02-20

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def __init__(
type_handler_registry
)
)
if self._array_metadata_store:
self._array_metadata_store.set_primary_host(self._primary_host)
self._array_metadata_validator = array_metadata_validator


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,45 @@ def __init__(
self,
path_resolver: PathResolver = PathResolver(),
ser_deser: SerDeserializer = SerDeserializer(),
primary_host: int | None = 0, # None means all hosts are primary hosts.
write_timeout_secs: int = 600, # 10 minutes.
):
self._path_resolver = path_resolver
self._ser_deser = ser_deser
self._primary_host = primary_host
self._write_timeout_secs = write_timeout_secs

def set_primary_host(self, primary_host: int | None) -> None:
  """Replaces the primary host used by this store.

  Args:
    primary_host: Process index to treat as the primary host, or `None` to
      treat all hosts as primary hosts.
  """
  self._primary_host = primary_host

async def _maybe_create_base_dir(self, base_dir: epath.Path) -> None:
  """Makes sure `base_dir` exists before this host writes into it.

  The primary host creates the directory (parents included, tolerating an
  already existing directory). Every other host polls every 0.25 seconds until
  the directory shows up, giving up after `self._write_timeout_secs` seconds.

  Args:
    base_dir: Directory that must exist before writing.

  Raises:
    ValueError: If a non-primary host times out while waiting for `base_dir`
      to be created.
  """
  if not multihost.is_primary_host(self._primary_host):
    # Non-primary host: wait for the primary host to create the directory.
    async def _poll_until_present():
      dir_exists = await asyncio.to_thread(base_dir.exists)
      while not dir_exists:
        await asyncio.sleep(0.25)
        dir_exists = await asyncio.to_thread(base_dir.exists)

    try:
      await asyncio.wait_for(
          _poll_until_present(), timeout=self._write_timeout_secs
      )
    except asyncio.TimeoutError as e:
      if self._primary_host is None:
        primary_process = 'LOCAL'
      else:
        primary_process = self._primary_host
      raise ValueError(
          f'[process_index={multihost.process_index()}] Timed out waiting for'
          f' array_metadatas base directory creation: {base_dir}.'
          f' timeout={self._write_timeout_secs} seconds.'
          f' primary_process={primary_process}'
      ) from e
    return None

  # Primary host: create the directory off the event loop thread.
  return await asyncio.to_thread(base_dir.mkdir, parents=True, exist_ok=True)

async def write(
self,
Expand All @@ -155,7 +191,7 @@ async def write(
file_path = self._path_resolver.get_write_file_path(
checkpoint_dir, process_index
)
await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
await self._maybe_create_base_dir(file_path.parent)
await asyncio.to_thread(
file_path.write_text, self._ser_deser.serialize(array_metadatas)
)
Expand Down
11 changes: 10 additions & 1 deletion checkpoint/orbax/checkpoint/experimental/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines exported symbols for the namespace package `orbax.checkpoint.v1`."""
"""Defines exported symbols for `orbax.checkpoint.experimental.v1`.

Prefer to use the style::

  import orbax.checkpoint.experimental.v1 as ocp
"""

# pylint: disable=g-importing-member, g-multiple-import

from orbax.checkpoint.experimental.v1._src.context import options
from orbax.checkpoint.experimental.v1._src.context.context import (
Context,
)
from orbax.checkpoint.experimental.v1._src.loading.loading import (
load_pytree,
load_pytree_async,
Expand Down
23 changes: 23 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/context/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Build targets for the V1 checkpointing context package.
package(default_visibility = ["//visibility:public"])

# Library built from context.py; depends on the options library below.
py_library(
    name = "context",
    srcs = ["context.py"],
    deps = [":options"],
)

# Library built from options.py; depends on the V1 tree types.
py_library(
    name = "options",
    srcs = ["options.py"],
    deps = ["//orbax/checkpoint/experimental/v1/_src/tree:types"],
)

# Test built from context_test.py; also depends on the top-level V1 package.
py_test(
    name = "context_test",
    srcs = ["context_test.py"],
    deps = [
        ":context",
        ":options",
        "//orbax/checkpoint/experimental/v1",
    ],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Orbax context for customized checkpointing."""

from __future__ import annotations

from collections.abc import Iterable
import contextvars
import dataclasses

from etils import epy
from orbax.checkpoint.experimental.v1._src.context import options as options_lib


# Holds the currently active `Context`, or `None` when no `Context` has been
# entered. Each thread has its own value, so a `Context` entered in one thread
# is not visible from another thread. Asyncio tasks and task groups likewise
# get their own copy.
_CONTEXT: contextvars.ContextVar[Context | None] = contextvars.ContextVar(
    "orbax_context", default=None
)


def get_context(default: Context | None = None) -> Context:
  """Returns the `Context` that applies to the caller.

  Resolution order:
    1. The `Context` currently entered via `with Context(...)`, if any.
    2. `default`, if provided.
    3. A newly constructed `Context()` with default options.

  Args:
    default: Fallback to return when no `Context` is currently active.

  Returns:
    The resolved `Context`.
  """
  active = _CONTEXT.get(None)
  if active is not None:
    return active
  # Build the fallback lazily so that calls made while a `Context` is active
  # do not allocate a throwaway `Context()`.
  return Context() if default is None else default


@dataclasses.dataclass(frozen=True, kw_only=True)
class Context(epy.ContextManager):
  """Immutable bundle of options that customizes checkpointing operations.

  Enter a `Context` to make it the active one for the enclosed code::

    with ocp.Context(...):
      ocp.save_pytree(...)

  NOTE: The context is not shared across threads, so the whole `with` block
  must be executed in the same thread. The following example will not work as
  expected::

    executor = ThreadPoolExecutor()
    with ocp.Context(...):  # Thread #1 creates Context A.
      executor.submit(ocp.save_pytree, ...)  # Thread #2 sees "default" Context.

  Attributes:
    pytree_options: Options for PyTree checkpointing.
  """

  pytree_options: options_lib.PyTreeOptions = dataclasses.field(
      default_factory=options_lib.PyTreeOptions
  )

  def __contextmanager__(self) -> Iterable[Context]:
    # Install this instance as the active context and remember how to undo it.
    reset_token = _CONTEXT.set(self)
    try:
      yield self
    finally:
      # Restore whatever was active before, even if the block raised.
      _CONTEXT.reset(reset_token)
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for checkpointing with Context."""

from concurrent import futures
from absl.testing import absltest
import orbax.checkpoint.experimental.v1 as ocp
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib

PyTreeOptions = options_lib.PyTreeOptions


def fake_checkpoint_operation() -> ocp.Context:
  """Stands in for a checkpoint operation by returning the current `Context`."""
  return context_lib.get_context()


class ContextTest(absltest.TestCase):
  """Checks which `Context` a checkpoint operation observes."""

  def test_default_context(self):
    # No context entered: the operation sees default options.
    ctx = fake_checkpoint_operation()
    self.assertEqual(ctx.pytree_options, PyTreeOptions())

    # A default-constructed context entered inline.
    with ocp.Context():
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions())

    # A default-constructed context created first, then entered.
    context = ocp.Context()
    with context:
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions())

  def test_custom_context(self):
    # Custom options are visible inside the `with` block.
    with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)):
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False))

    # Same check with the context created before it is entered.
    context = ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False))
    with context:
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False))

  def test_custom_context_in_separate_thread_becomes_default(self):
    # Executor created outside the context: the worker thread sees defaults.
    with futures.ThreadPoolExecutor(max_workers=1) as executor:
      with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)):
        future = executor.submit(fake_checkpoint_operation)
        ctx = future.result()
        self.assertEqual(ctx.pytree_options, PyTreeOptions())

    # Executor created inside the context: the worker thread still sees
    # defaults.
    with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)):
      with futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fake_checkpoint_operation)
        ctx = future.result()
        self.assertEqual(ctx.pytree_options, PyTreeOptions())

  def test_custom_context_in_same_thread_remains_custom(self):
    # The context is entered and read within the same worker thread.
    def test_fn():
      with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)):
        ctx = fake_checkpoint_operation()
        return ctx

    with futures.ThreadPoolExecutor(max_workers=1) as executor:
      future = executor.submit(test_fn)
      ctx = future.result()
      self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False))

  def test_nested_contexts(self):
    with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)):
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False))

      # The inner context replaces the outer one while it is active.
      with ocp.Context(pytree_options=PyTreeOptions(use_ocdbt=False)):
        ctx = fake_checkpoint_operation()
        self.assertEqual(ctx.pytree_options, PyTreeOptions(use_ocdbt=False))

      # Leaving the inner context restores the outer one.
      ctx = fake_checkpoint_operation()
      self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False))


if __name__ == "__main__":
absltest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configurable options for customizing checkpointing behavior."""

from __future__ import annotations

import dataclasses
from typing import Any
from typing import Protocol

import numpy as np
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types


@dataclasses.dataclass(frozen=True, kw_only=True)
class PyTreeOptions:
  """Options for PyTree checkpointing.

  Attributes:
    use_ocdbt: Whether to use OCDBT for saving.
    use_zarr3: Whether to use Zarr3 for saving.
    save_concurrent_bytes: The maximum number of bytes to save concurrently.
    restore_concurrent_bytes: The maximum number of bytes to restore
      concurrently.
    ocdbt_target_data_file_size: Specifies the target size (in bytes) of each
      OCDBT data file. It only applies when OCDBT is enabled and Zarr3 must be
      turned on. If left unspecified, default size is 2GB. A value of 0
      indicates no maximum file size limit. For best results, ensure
      chunk_byte_size is smaller than this value. For more details, refer to
      https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size
    enable_pinned_host_transfer: If False, disables transfer to pinned host when
      copying from device to host, regardless of the presence of pinned host
      memory.
    partial_load: If the tree structure omits some keys relative to the
      checkpoint, the omitted keys will not be loaded.
    array_storage_options_creator: A function that is applied to each leaf of
      the input PyTree (via `jax.tree.map_with_path`) to create a
      `ArrayStorageOptions` object, which is used to customize saving behavior
      for individual leaves. See `ArrayStorageOptions` and
      `ArrayStorageOptionsCreator` for more details.
  """

  use_ocdbt: bool = True
  use_zarr3: bool = True
  save_concurrent_bytes: int | None = None
  restore_concurrent_bytes: int | None = None
  ocdbt_target_data_file_size: int | None = None
  enable_pinned_host_transfer: bool = False
  partial_load: bool = False
  array_storage_options_creator: ArrayStorageOptionsCreator | None = None


@dataclasses.dataclass
class ArrayStorageOptions:
  """Per-leaf settings that customize how an individual array is stored.

  Attributes:
    dtype: If provided, the parameter is cast to this dtype before saving. The
      parameter must be compatible with the given type (e.g. jnp.bfloat16 is
      not compatible with np.ndarray). Any value given is normalized to a
      `np.dtype` instance after construction.
    chunk_byte_size: Experimental. Automatically chooses the largest chunk
      shape possible while keeping the chunk byte size less than or equal to
      this value. Both the write_chunk_shape and read_chunk_shape are set to
      the chosen shape. A greedy algorithm is used that prioritizes splitting
      the largest dimensions first.
    shard_axes: Optional axes that should be prioritized when sharding the
      array for storage. If empty, the storage sharding implementation
      prioritizes axes which are already sharded.
  """

  dtype: np.typing.DTypeLike | None = None
  chunk_byte_size: int | None = None
  shard_axes: tuple[int, ...] = ()

  def __post_init__(self):
    # Nothing to normalize when no dtype override was requested.
    if self.dtype is None:
      return
    self.dtype = np.dtype(self.dtype)


class ArrayStorageOptionsCreator(Protocol):
  """Protocol for callables that build `ArrayStorageOptions` for a leaf.

  An implementation is called with the key path and value of a PyTree leaf and
  returns the `ArrayStorageOptions` to use for that leaf.
  """

  def __call__(
      self, key: tree_types.PyTreeKeyPath, value: Any
  ) -> ArrayStorageOptions:
    ...
Loading

0 comments on commit 5ffbb6e

Please sign in to comment.