-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
ArrayStorageOptions
to customize per-leaf saving behavior for a…
…rrays (e.g. `dtype`). PiperOrigin-RevId: 726245749
- Loading branch information
1 parent
4a24304
commit f80dab8
Showing
21 changed files
with
625 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
23 changes: 23 additions & 0 deletions
23
checkpoint/orbax/checkpoint/experimental/v1/_src/context/BUILD
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
py_library( | ||
name = "context", | ||
srcs = ["context.py"], | ||
deps = [":options"], | ||
) | ||
|
||
py_library( | ||
name = "options", | ||
srcs = ["options.py"], | ||
deps = ["//orbax/checkpoint/experimental/v1/_src/tree:types"], | ||
) | ||
|
||
py_test( | ||
name = "context_test", | ||
srcs = ["context_test.py"], | ||
deps = [ | ||
":context", | ||
":options", | ||
"//orbax/checkpoint/experimental/v1", | ||
], | ||
) |
71 changes: 71 additions & 0 deletions
71
checkpoint/orbax/checkpoint/experimental/v1/_src/context/context.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Orbax context for customized checkpointing.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections.abc import Iterable | ||
import contextvars | ||
import dataclasses | ||
|
||
from etils import epy | ||
from orbax.checkpoint.experimental.v1._src.context import options as options_lib | ||
|
||
|
||
# Each Thread will have its own copy of `Context` object. | ||
# Task and groups will have their own copy of `Context` object. | ||
_CONTEXT: contextvars.ContextVar[Context] = contextvars.ContextVar( | ||
"orbax_context", default=None | ||
) | ||
|
||
|
||
def get_context(default: Context | None = None) -> Context: | ||
"""Returns the current `Context` or `default` or `Context()` if not set.""" | ||
default = default or Context() | ||
return _CONTEXT.get(default) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True, kw_only=True) | ||
class Context(epy.ContextManager): | ||
"""Context for customized checkpointing. | ||
Usage example:: | ||
with ocp.Context(...): | ||
ocp.save_pytree(...) | ||
NOTE: The context is not shared across threads. In other words, the whole | ||
context block must be executed in the same thread. Following example will | ||
not work as expected:: | ||
executor = ThreadPoolExecutor() | ||
with ocp.Context(...): # Thread #1 creates Context A. | ||
executor.submit(ocp.save_pytree, ...) # Thread #2 sees "default" Context. | ||
Attributes: | ||
pytree_options: Options for PyTree checkpointing. | ||
""" | ||
|
||
pytree_options: options_lib.PyTreeOptions = dataclasses.field( | ||
default_factory=options_lib.PyTreeOptions | ||
) | ||
|
||
def __contextmanager__(self) -> Iterable[Context]: | ||
token = _CONTEXT.set(self) | ||
try: | ||
yield self | ||
finally: | ||
_CONTEXT.reset(token) |
93 changes: 93 additions & 0 deletions
93
checkpoint/orbax/checkpoint/experimental/v1/_src/context/context_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for checkpointing with Context.""" | ||
|
||
from concurrent import futures | ||
from absl.testing import absltest | ||
import orbax.checkpoint.experimental.v1 as ocp | ||
from orbax.checkpoint.experimental.v1._src.context import context as context_lib | ||
from orbax.checkpoint.experimental.v1._src.context import options as options_lib | ||
|
||
PyTreeOptions = options_lib.PyTreeOptions | ||
|
||
|
||
def fake_checkpoint_operation() -> ocp.Context: | ||
return context_lib.get_context() | ||
|
||
|
||
class ContextTest(absltest.TestCase): | ||
|
||
def test_default_context(self): | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions()) | ||
|
||
with ocp.Context(): | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions()) | ||
|
||
context = ocp.Context() | ||
with context: | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions()) | ||
|
||
def test_custom_context(self): | ||
with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)): | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False)) | ||
|
||
context = ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)) | ||
with context: | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False)) | ||
|
||
def test_custom_context_in_separate_thread_becomes_default(self): | ||
with futures.ThreadPoolExecutor(max_workers=1) as executor: | ||
with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)): | ||
future = executor.submit(fake_checkpoint_operation) | ||
ctx = future.result() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions()) | ||
|
||
with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)): | ||
with futures.ThreadPoolExecutor(max_workers=1) as executor: | ||
future = executor.submit(fake_checkpoint_operation) | ||
ctx = future.result() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions()) | ||
|
||
def test_custom_context_in_same_thread_remains_custom(self): | ||
def test_fn(): | ||
with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)): | ||
ctx = fake_checkpoint_operation() | ||
return ctx | ||
|
||
with futures.ThreadPoolExecutor(max_workers=1) as executor: | ||
future = executor.submit(test_fn) | ||
ctx = future.result() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False)) | ||
|
||
def test_nested_contexts(self): | ||
with ocp.Context(pytree_options=PyTreeOptions(use_zarr3=False)): | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False)) | ||
|
||
with ocp.Context(pytree_options=PyTreeOptions(use_ocdbt=False)): | ||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_ocdbt=False)) | ||
|
||
ctx = fake_checkpoint_operation() | ||
self.assertEqual(ctx.pytree_options, PyTreeOptions(use_zarr3=False)) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
98 changes: 98 additions & 0 deletions
98
checkpoint/orbax/checkpoint/experimental/v1/_src/context/options.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright 2024 The Orbax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Configurable options for customizing checkpointing behavior.""" | ||
|
||
from __future__ import annotations | ||
|
||
import dataclasses | ||
from typing import Any | ||
from typing import Protocol | ||
|
||
import numpy as np | ||
from orbax.checkpoint.experimental.v1._src.tree import types as tree_types | ||
|
||
|
||
@dataclasses.dataclass(frozen=True, kw_only=True) | ||
class PyTreeOptions: | ||
"""Options for PyTree checkpointing. | ||
Attributes: | ||
use_ocdbt: Whether to use OCDBT for saving. | ||
use_zarr3: Whether to use Zarr3 for saving. | ||
save_concurrent_bytes: The maximum number of bytes to save concurrently. | ||
restore_concurrent_bytes: The maximum number of bytes to restore | ||
concurrently. | ||
ocdbt_target_data_file_size: Specifies the target size (in bytes) of each | ||
OCDBT data file. It only applies when OCDBT is enabled and Zarr3 must be | ||
turned on. If left unspecified, default size is 2GB. A value of 0 | ||
indicates no maximum file size limit. For best results, ensure | ||
chunk_byte_size is smaller than this value. For more details, refer to | ||
https://google.github.io/tensorstore/kvstore/ocdbt/index.html#json-kvstore/ocdbt.target_data_file_size | ||
enable_pinned_host_transfer: If False, disables transfer to pinned host when | ||
copying from device to host, regardless of the presence of pinned host | ||
memory. | ||
partial_load: If the tree structure omits some keys relative to the | ||
checkpoint, the omitted keys will not be loaded. | ||
array_storage_options_creator: A function that is applied to each leaf of | ||
the input PyTree (via `jax.tree.map_with_path`) to create a | ||
`ArrayStorageOptions` object, which is used to customize saving behavior | ||
for individual leaves. See `ArrayStorageOptions` and | ||
`ArrayStorageOptionsCreator` for more details. | ||
""" | ||
|
||
use_ocdbt: bool = True | ||
use_zarr3: bool = True | ||
save_concurrent_bytes: int | None = None | ||
restore_concurrent_bytes: int | None = None | ||
ocdbt_target_data_file_size: int | None = None | ||
enable_pinned_host_transfer: bool = False | ||
partial_load: bool = False | ||
array_storage_options_creator: ArrayStorageOptionsCreator | None = None | ||
|
||
|
||
@dataclasses.dataclass | ||
class ArrayStorageOptions: | ||
"""Arguments used to customize array storage behavior for individual leaves. | ||
dtype: | ||
If provided, casts the parameter to the given dtype before saving. | ||
Note that the parameter must be compatible with the given type (e.g. | ||
jnp.bfloat16 is not compatible with np.ndarray). | ||
chunk_byte_size: | ||
This is an experimental feature that automatically chooses the largest chunk | ||
shape possible, while keeping the chunk byte size less than or equal to the | ||
specified chunk_byte_size. Both the write_chunk_shape and read_chunk_shape | ||
are automatically set to the chosen shape. This uses a greedy algorithm that | ||
prioritizes splitting the largest dimensions first. | ||
shard_axes: An optional list of axes that should be prioritized when | ||
sharding array for storage. If empty, storage sharding implementation will | ||
prioritize axes which are already sharded. | ||
""" | ||
|
||
dtype: np.typing.DTypeLike | None = None | ||
chunk_byte_size: int | None = None | ||
shard_axes: tuple[int, ...] = tuple() | ||
|
||
def __post_init__(self): | ||
if self.dtype is not None: | ||
self.dtype = np.dtype(self.dtype) | ||
|
||
|
||
class ArrayStorageOptionsCreator(Protocol): | ||
|
||
def __call__( | ||
self, key: tree_types.PyTreeKeyPath, value: Any | ||
) -> ArrayStorageOptions: | ||
... |
Oops, something went wrong.