Skip to content

Commit 1c1f359

Browse files
author
Orbax Authors
committed
Fix checkpoint key value restoration when working with dictionary integer keys
PiperOrigin-RevId: 836839428
1 parent d966ddf commit 1c1f359

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+566
-290
lines changed

checkpoint/CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
## [0.11.30] - 2025-11-26
11+
12+
### Fixed
13+
14+
- Roll back earlier change altering metadata format, which was observed to cause
15+
breakages.
16+
17+
## [0.11.29] - 2025-11-25
18+
1019
### Fixed
1120

1221
- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
1322
include an arbitrary `step_prefix` with any character(s) such as underscores.
1423
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
24+
- Fix using `jax.eval_shape` with `StandardRestore`.
1525

1626
### Changed
1727

1828
- Validate checkpoints before writing merged OCDBT database using in-memory
1929
state, avoiding additional I/O to re-read metadata.
2030
- Add `support_format` to `utils.to_shape_dtype_struct()`.
2131
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
32+
- Replace usage of `get_json_tspec_read` and delegate functionality to new
33+
function `build_array_read_spec` which constructs and returns an
34+
`ArrayReadSpec`.
2235

2336
## [0.11.28] - 2025-11-06
2437

checkpoint/orbax/checkpoint/_src/handlers/standard_checkpoint_handler_test_utils.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
505505
return StandardCheckpointHandler()
506506

507507
def test_with_random_keys(self):
508+
# TODO(b/393160483) investigate Pathways remote Python support for
509+
# random.keys.
508510
if utils.is_pathways_backend():
509-
self.skipTest('Pathways does not support random keys checkpoint.')
511+
self.skipTest(
512+
'Disabled on Pathways because random keys are not supported by'
513+
' remote Python.'
514+
)
510515

511516
def create_random_keys(seed):
512517
duplicated_sharding = jax.sharding.NamedSharding(
@@ -559,3 +564,38 @@ def create_random_keys(seed):
559564
args=self.restore_args_cls(abstract_tree),
560565
)
561566
test_utils.assert_tree_equal(self, self.pytree, restored)
567+
568+
def test_save_restore_random_keys_with_jax_eval_shape(self):
569+
# TODO(b/393160483) investigate Pathways remote Python support for
570+
# random.keys.
571+
if utils.is_pathways_backend():
572+
self.skipTest(
573+
'Disabled on Pathways because random keys are not supported by'
574+
' remote Python.'
575+
)
576+
577+
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
578+
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
579+
580+
@functools.partial(
581+
jax.jit,
582+
in_shardings=sharding,
583+
out_shardings=sharding,
584+
)
585+
def sharded_create_state_fn(root_key):
586+
return dict(
587+
matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
588+
rngkey=jax.random.fold_in(root_key, 42),
589+
)
590+
591+
pytree = sharded_create_state_fn(jax.random.key(0))
592+
abstract_pytree = jax.eval_shape(
593+
sharded_create_state_fn, jax.random.key(0)
594+
)
595+
596+
self.handler.save(self.directory, args=self.save_args_cls(pytree))
597+
598+
restored = self.handler.restore(
599+
self.directory, args=self.restore_args_cls(abstract_pytree)
600+
)
601+
test_utils.assert_tree_equal(self, pytree, restored)

checkpoint/orbax/checkpoint/_src/path/atomicity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ async def _create_tmp_directory(
161161
def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
162162
# Path may not be completely unique if a preemption occurs. We rely on the
163163
# existing tmp directory being deleted elsewhere.
164-
return epath.Path(final_path.parent) / (final_path.name + TMP_DIR_SUFFIX)
164+
return final_path.parent / (final_path.name + TMP_DIR_SUFFIX)
165165

166166

167167
def _get_final_directory(tmp_path: epath.Path) -> epath.Path:
168168
if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1:
169169
raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".')
170-
return epath.Path(tmp_path.parent) / tmp_path.name[:suffix_idx]
170+
return tmp_path.parent / tmp_path.name[:suffix_idx]
171171

172172

173173
class TemporaryPathBase(atomicity_types.TemporaryPath):

checkpoint/orbax/checkpoint/_src/path/deleter.py

Lines changed: 38 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
import time
2323
from typing import Optional, Protocol, Sequence
24+
from urllib import parse
2425

2526
from absl import logging
2627
from etils import epath
@@ -31,6 +32,7 @@
3132
from orbax.checkpoint._src.path import step as step_lib
3233

3334

35+
urlparse = parse.urlparse
3436
PurePosixPath = pathlib.PurePosixPath
3537

3638
_THREADED_DELETE_DURATION = (
@@ -183,7 +185,9 @@ def delete(self, step: int) -> None:
183185
# Attempt to rename using GCS HNS API if configured.
184186
if self._todelete_full_path is not None:
185187
if gcs_utils.is_gcs_path(self._directory):
186-
self._rename_gcs_step_with_hns(step, delete_target)
188+
# This is recommended for GCS buckets with HNS enabled and requires
189+
# `_todelete_full_path` to be specified.
190+
self._gcs_rename_step(step, delete_target)
187191
else:
188192
raise NotImplementedError()
189193
# Attempt to rename to local subdirectory using `todelete_subdir`
@@ -204,88 +208,56 @@ def delete(self, step: int) -> None:
204208
time.time() - start,
205209
)
206210

207-
def _rename_gcs_step_with_hns(
211+
def _gcs_rename_step(
208212
self, step: int, delete_target: epath.Path
209213
):
210-
"""Renames a GCS directory using the Storage Control API.
214+
"""Renames a GCS directory to a temporary location for deletion.
215+
216+
This method renames the directory using the
217+
underlying `tf.io.gfile.rename` method. This underlying
218+
implementation will automatically detect if the bucket is HNS-enabled
219+
and use a fast atomic rename, or fall back to a legacy
220+
copy/delete rename if it is not.
211221
212222
Args:
213223
step: The checkpoint step number.
214224
delete_target: The path to the directory to be renamed.
215-
216-
Raises:
217-
ValueError: If the GCS bucket is not HNS-enabled, as this is a
218-
hard requirement for this operation.
219225
"""
220-
logging.info(
221-
'Condition: GCS path with `todelete_full_path` set. Checking for HNS.'
222-
)
223-
bucket_name, _ = gcs_utils.parse_gcs_path(self._directory)
224-
if not gcs_utils.is_hierarchical_namespace_enabled(self._directory):
225-
raise ValueError(
226-
f'Bucket "{bucket_name}" does not have Hierarchical Namespace'
227-
' enabled, which is required when _todelete_full_path is set.'
228-
)
229-
230-
logging.info('HNS bucket detected. Attempting to rename step %d.', step)
231-
# pylint: disable=g-import-not-at-top
232-
from google.api_core import exceptions as google_exceptions # pytype: disable=import-error
233226
try:
234-
from google.cloud import storage_control_v2 # pytype: disable=import-error
235-
import google.auth # pytype: disable=import-error
236-
237-
# Use default credentials, but without a quota project to avoid
238-
# quota issues with this API.
239-
credentials, _ = google.auth.default()
240-
creds_without_quota_project = credentials.with_quota_project(None)
241-
client = storage_control_v2.StorageControlClient(
242-
credentials=creds_without_quota_project
243-
)
244-
# Destination parent is the absolute path to the bucket.
245-
destination_parent_dir_str = (
227+
# Get the bucket name from the source path
228+
bucket_name = urlparse(str(delete_target)).netloc
229+
if not bucket_name:
230+
raise ValueError(
231+
f'Could not parse bucket name from path: {delete_target}'
232+
)
233+
234+
# Construct the destination path inside the `_todelete_full_path` dir.
235+
destination_parent_path = epath.Path(
246236
f'gs://{bucket_name}/{self._todelete_full_path}'
247237
)
248-
destination_parent_path = PurePosixPath(destination_parent_dir_str)
249-
logging.info(
250-
'Ensuring destination parent folder exists via HNS API: %s',
251-
destination_parent_dir_str,
252-
)
253-
try:
254-
parent_folder_id = str(
255-
destination_parent_path.relative_to(f'gs://{bucket_name}')
256-
)
257-
bucket_resource_name = f'projects/_/buckets/{bucket_name}'
258-
client.create_folder(
259-
request=storage_control_v2.CreateFolderRequest(
260-
parent=bucket_resource_name,
261-
folder_id=parent_folder_id,
262-
recursive=True,
263-
)
264-
)
265-
logging.info('HNS parent folder creation request sent.')
266-
except google_exceptions.AlreadyExists:
267-
logging.info('HNS parent folder already exists, proceeding.')
238+
destination_parent_path.mkdir(parents=True, exist_ok=True)
268239

240+
# Create a unique name for the destination to avoid collisions.
269241
now = datetime.datetime.now()
270242
timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f')
271243
new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}'
272244
dest_path = destination_parent_path / new_name_with_timestamp
273-
source_folder_id = str(delete_target.relative_to(f'gs://{bucket_name}'))
274-
destination_folder_id = str(dest_path.relative_to(f'gs://{bucket_name}'))
275-
source_resource_name = (
276-
f'projects/_/buckets/{bucket_name}/folders/{source_folder_id}'
277-
)
278-
logging.info('Rename API call: Source: %s', source_resource_name)
279-
logging.info('Rename API call: Destination ID: %s', destination_folder_id)
280-
request = storage_control_v2.RenameFolderRequest(
281-
name=source_resource_name,
282-
destination_folder_id=destination_folder_id,
245+
246+
logging.info(
247+
'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`',
248+
delete_target,
249+
dest_path,
283250
)
284-
op = client.rename_folder(request=request)
285-
op.result()
251+
252+
# Call the high-level rename method.
253+
# This will be fast on HNS and slow (but functional) on non-HNS.
254+
delete_target.rename(dest_path)
286255
logging.info('Successfully renamed step %d to %s', step, dest_path)
287-
except google_exceptions.GoogleAPIError as e:
288-
logging.error('HNS rename failed for step %d. Error: %s', step, e)
256+
257+
except Exception as e:
258+
message = f'Rename failed for step {step}. Error: {e}'
259+
logging.error(message)
260+
raise RuntimeError(message) from e
289261

290262
def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):
291263
"""Renames a step directory to its corresponding todelete_subdir."""

checkpoint/orbax/checkpoint/_src/path/deleter_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
# limitations under the License.
1414

1515
"""To test Orbax in single-host setup."""
16+
17+
import unittest
18+
from unittest import mock
19+
1620
from absl.testing import absltest
1721
from absl.testing import parameterized
1822
from etils import epath
@@ -64,5 +68,51 @@ def test_checkpoint_deleter_delete(
6468
deleter.close()
6569

6670

71+
class GcsRenameTest(unittest.TestCase):
72+
73+
@mock.patch('orbax.checkpoint._src.path.deleter.epath.Path')
74+
def test_gcs_rename_logic_directly(self, mock_epath_constructor):
75+
"""Tests path construction and rename call logic."""
76+
standard_checkpoint_deleter = deleter_lib.StandardCheckpointDeleter
77+
78+
deleter = standard_checkpoint_deleter(
79+
directory=mock.MagicMock(),
80+
name_format=step_lib.standard_name_format(),
81+
primary_host=None,
82+
todelete_subdir=None,
83+
todelete_full_path='trash_bin',
84+
enable_hns=False,
85+
)
86+
# When epath.Path() is called inside the code, it returns this mock parent
87+
mock_dest_parent = mock.MagicMock()
88+
mock_epath_constructor.return_value = mock_dest_parent
89+
90+
# When the code does (parent / child), return a specific final mock
91+
mock_final_dest = mock.MagicMock()
92+
mock_final_dest.__str__.return_value = 'gs://mocked/final/destination'
93+
mock_dest_parent.__truediv__.return_value = mock_final_dest
94+
95+
# Setup the "Source" Mock (The step being deleted)
96+
mock_step_path = mock.MagicMock()
97+
mock_step_path.__str__.return_value = 'gs://my-bucket/checkpoints/step_10'
98+
mock_step_path.name = 'step_10'
99+
100+
deleter._gcs_rename_step(step=10, delete_target=mock_step_path)
101+
102+
# Verify mkdir was called on the destination parent.
103+
mock_dest_parent.mkdir.assert_called_with(parents=True, exist_ok=True)
104+
105+
# Verify the Parent Path string was constructed correctly
106+
# The code does: epath.Path(f'gs://{bucket}/{todelete_full_path}')
107+
(parent_path_arg,), _ = mock_epath_constructor.call_args
108+
self.assertEqual(parent_path_arg, 'gs://my-bucket/trash_bin')
109+
110+
# Verify the Child Filename was constructed correctly
111+
(child_name_arg,), _ = mock_dest_parent.__truediv__.call_args
112+
self.assertIn('step_10-', child_name_arg)
113+
114+
# Verify the Rename was actually called
115+
mock_step_path.rename.assert_called_with(mock_final_dest)
116+
67117
if __name__ == '__main__':
68118
absltest.main()

checkpoint/orbax/checkpoint/_src/path/gcs_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def get_bucket(bucket_name: str):
5050

5151
def is_hierarchical_namespace_enabled(path: epath.PathLike) -> bool:
5252
"""Return whether hierarchical namespace is enabled."""
53+
parsed = parse.urlparse(str(path))
54+
if parsed.scheme != 'gs':
55+
return False
5356
bucket_name, _ = parse_gcs_path(path)
5457
bucket = get_bucket(bucket_name)
5558
return bucket.hierarchical_namespace_enabled

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,11 @@ class _StandardNameFormat(NameFormat[Metadata]):
321321
single_host_load_and_broadcast: If True, the jax process=0 will list all
322322
steps and broadcast them to all other processes. NOTE: Ignored if jax
323323
backend is not multi controller.
324-
enable_hns: Enables HNS-specific path logic.
325324
"""
326325

327326
step_prefix: Optional[str] = None
328327
step_format_fixed_length: Optional[int] = None
329328
single_host_load_and_broadcast: bool = False
330-
enable_hns: bool = False
331329

332330
def __str__(self):
333331
return f'StandardNameFormat("{self.build_name(1234)}")'
@@ -375,9 +373,7 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
375373
"""Returns step paths under `base_path`."""
376374
base_path = epath.Path(base_path)
377375
# <step_prefix>_?<0 padding>?*
378-
if self.enable_hns and gcs_utils.is_hierarchical_namespace_enabled(
379-
base_path
380-
):
376+
if gcs_utils.is_hierarchical_namespace_enabled(base_path):
381377
logging.vlog(
382378
1,
383379
'HNS enabled. Using GCS API to list step paths at %s',
@@ -560,7 +556,6 @@ def standard_name_format(
560556
step_prefix: Optional[str] = None,
561557
step_format_fixed_length: Optional[int] = None,
562558
single_host_load_and_broadcast: bool = False,
563-
enable_hns: bool = False,
564559
) -> NameFormat[Metadata]:
565560
"""Returns NameFormat for 'standard' steps for common Orbax use cases.
566561
@@ -580,13 +575,11 @@ def standard_name_format(
580575
single_host_load_and_broadcast: If True, the jax process=0 will list all
581576
steps and broadcast them to all other processes. NOTE: Ignored if jax
582577
backend is not multi controller.
583-
enable_hns: Enables HNS-specific path logic.
584578
"""
585579
return _StandardNameFormat(
586580
step_prefix=step_prefix,
587581
step_format_fixed_length=step_format_fixed_length,
588582
single_host_load_and_broadcast=single_host_load_and_broadcast,
589-
enable_hns=enable_hns,
590583
)
591584

592585

0 commit comments

Comments
 (0)