Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to enable continuous checkpointing in Gemax Prod. Note that this is an experimental feature and should not be adopted by users without explicit approval. #1530

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ properties not included in any tree mapping operations.
- Added github actions CI testing using Python versions 3.10-3.13.
- Standardize naming of the "custom metadata" field (user-supplied metadata) as
`custom_metadata`.
- Add `SaveIntervalPolicy` to better encapsulate various options around choosing
whether or not to perform a save at a particular step.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- A policy that allows for checkpointing as often as possible, as long as a
save is not already in progress (continuous checkpointing).
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- User-provided custom PyTree metadata.
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import checkpointers
from orbax.checkpoint import checkpoint_managers
from orbax.checkpoint import handlers
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
Expand Down
12 changes: 12 additions & 0 deletions checkpoint/orbax/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ py_library(
":args",
":arrays",
":checkpoint_manager",
":checkpoint_managers",
":checkpoint_utils",
":checkpointers",
":future",
Expand Down Expand Up @@ -133,6 +134,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:step",
"//checkpoint/orbax/checkpoint/_src/path:utils",
"//third_party/py/jax/experimental/array_serialization:serialization",
"//orbax/checkpoint/_src/checkpoint_managers:save_interval_policy",
"//orbax/checkpoint/_src/checkpointers:abstract_checkpointer",
"//orbax/checkpoint/_src/checkpointers:async_checkpointer",
"//orbax/checkpoint/_src/handlers:handler_registration",
Expand Down Expand Up @@ -353,3 +355,13 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
name = "checkpoint_managers",
srcs = ["checkpoint_managers.py"],
deps = [
":abstract_checkpoint_manager",
":checkpoint_manager",
"//orbax/checkpoint/_src/checkpoint_managers:save_interval_policy",
],
)
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import checkpointers
from orbax.checkpoint import checkpoint_managers
from orbax.checkpoint import handlers
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines policies for the interval at which checkpoints are saved."""

import dataclasses
import typing
from typing import Container, Protocol


@dataclasses.dataclass(kw_only=True)
class StepInfo:
"""Relevant information about a checkpoint step."""

step: int
is_saving_in_progress: bool
reached_preemption: bool


@typing.runtime_checkable
class SaveIntervalPolicy(Protocol):
"""A policy that defines when to save a checkpoint.

Implementations should return True from `should_save` when saving a checkpoint
is desired at the given step.
"""

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
...


@dataclasses.dataclass
class FixedIntervalPolicy(SaveIntervalPolicy):
"""Checkpoint at a fixed interval."""

interval: int

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
del previous_steps
return step.step % self.interval == 0


@dataclasses.dataclass
class SpecificStepsPolicy(SaveIntervalPolicy):

steps: Container[int]

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
del previous_steps
return step.step in self.steps


class ContinuousCheckpointingPolicy(SaveIntervalPolicy):
"""Checkpoint as often as possible, as long as a save is not ongoing."""

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
del previous_steps
return not step.is_saving_in_progress


class PreemptionCheckpointingPolicy(SaveIntervalPolicy):
"""Save a checkpoint when a preemption is detected."""

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
del previous_steps
return step.reached_preemption


class InitialSavePolicy(SaveIntervalPolicy):
"""Checkpoint as soon as possible if no checkpoints already exist."""

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
del step
return not previous_steps


@dataclasses.dataclass
class AnySavePolicy(SaveIntervalPolicy):

policies: list[SaveIntervalPolicy]

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
return any(
policy.should_save(step, previous_steps=previous_steps)
for policy in self.policies
)


@dataclasses.dataclass
class AllSavePolicy(SaveIntervalPolicy):

policies: list[SaveIntervalPolicy]

def should_save(
self, step: StepInfo, *, previous_steps: list[StepInfo]
) -> bool:
return any(
policy.should_save(step, previous_steps=previous_steps)
for policy in self.policies
)
112 changes: 85 additions & 27 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src.checkpoint_managers import save_interval_policy as save_interval_policy_lib
from orbax.checkpoint._src.checkpointers import abstract_checkpointer
from orbax.checkpoint._src.checkpointers import async_checkpointer
from orbax.checkpoint._src.checkpointers import checkpointer as checkpointer_lib
Expand Down Expand Up @@ -157,6 +158,53 @@ def join(self, *args, **kwargs):
raise self.exception


@dataclasses.dataclass
class _ShouldSaveFnPolicy(save_interval_policy_lib.SaveIntervalPolicy):
"""A policy that uses a user-provided should_save_fn."""

should_save_fn: Callable[[int, Optional[int]], bool]

def should_save(
self,
step: save_interval_policy_lib.StepInfo,
*,
previous_steps: list[save_interval_policy_lib.StepInfo],
) -> bool:
return self.should_save_fn(
step.step, previous_steps[-1].step if previous_steps else None
)


def _get_default_save_interval_policy(
options: CheckpointManagerOptions,
) -> save_interval_policy_lib.SaveIntervalPolicy:
"""Creates a default policy from CheckpointManagerOptions."""
save_interval_policies = []
if options.should_save_fn is not None:
save_interval_policies.append(_ShouldSaveFnPolicy(options.should_save_fn))
save_interval_policies.append(
save_interval_policy_lib.PreemptionCheckpointingPolicy()
)
else:
if options.save_interval_steps is not None:
save_interval_policies.append(
save_interval_policy_lib.FixedIntervalPolicy(
options.save_interval_steps
)
)
if options.save_on_steps is not None:
save_interval_policies.append(
save_interval_policy_lib.SpecificStepsPolicy(options.save_on_steps)
)
save_interval_policies.append(
save_interval_policy_lib.PreemptionCheckpointingPolicy()
)
save_interval_policies.append(
save_interval_policy_lib.InitialSavePolicy()
)
return save_interval_policy_lib.AnySavePolicy(save_interval_policies)


# TODO(b/268051457) Clean up when no longer depended upon by internal users.
def is_async_checkpointer(checkpointer: AbstractCheckpointer):
return isinstance(
Expand Down Expand Up @@ -293,6 +341,9 @@ class CheckpointManagerOptions:
file_options: FileOptions = dataclasses.field(default_factory=FileOptions)
save_root_metadata: bool = True
temporary_path_class: Optional[Type[atomicity_types.TemporaryPath]] = None
save_interval_policy: Optional[
save_interval_policy_lib.SaveIntervalPolicy
] = None

def __post_init__(self):
if self.best_mode not in ('min', 'max'):
Expand Down Expand Up @@ -362,7 +413,7 @@ def __post_init__(self):
self.keep_period,
)
self.keep_period = None
self.save_on_steps = frozenset(self.save_on_steps or ())
self.save_on_steps = set(self.save_on_steps or ())


@dataclasses.dataclass
Expand Down Expand Up @@ -576,6 +627,10 @@ def __init__(

self._options = options or CheckpointManagerOptions()
self._multiprocessing_options = self._options.multiprocessing_options
self._save_interval_policy = (
self._options.save_interval_policy
or _get_default_save_interval_policy(self._options)
)

if self._options.best_mode not in ['min', 'max']:
raise ValueError('`best_mode` must be one of: "min", "max"')
Expand Down Expand Up @@ -720,6 +775,10 @@ def __init__(
with self._finalize_thread_lock:
self._finalize_thread = None

self._is_saving_in_progress_lock = threading.Lock()
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = False

self._checkpoint_deleter: deleter.CheckpointDeleter = (
deleter.create_checkpoint_deleter(
self._multiprocessing_options.primary_host,
Expand Down Expand Up @@ -992,32 +1051,29 @@ def should_save(self, step: int) -> bool:
if self._options.read_only:
logging.warning('%s is read only, save will be skipped', self.directory)
return False
if self.reached_preemption(step):
return True
last_checkpoint_step = self.latest_step()
# Ensure current step is between the last step and next step (accounting for
# save interval). The `last_checkpoint_step` may not be initialized, in
# which case we should save. Otherwise, step must fall on the specified
# save interval. This condition accounts for the possibility of saving
# on preemption, in which case we want to maintain the same save period as
# if preemption had not happened.
# save interval).
if last_checkpoint_step is not None and last_checkpoint_step >= step:
return False
# If present then prefer should_save_fn over other 'save_*' options.
if self._options.should_save_fn is not None:
logging.log_every_n(
logging.INFO,
'CheckpointManagerOptions.should_save_fn is available, following save'
' options will be ignored: save_interval_steps=%s and'
' save_on_steps=%s',
500,
self._options.save_interval_steps,
self._options.save_on_steps,
)
return self._options.should_save_fn(step, last_checkpoint_step)
return last_checkpoint_step is None or (
step % self._options.save_interval_steps == 0
or step in self._options.save_on_steps

is_saving_in_progress = self.is_saving_in_progress()
reached_preemption = self.reached_preemption(step)
previous_step_infos = [
save_interval_policy_lib.StepInfo(
step=ckpt.step,
is_saving_in_progress=False,
reached_preemption=False,
)
for ckpt in self._checkpoints
]
current_step_info = save_interval_policy_lib.StepInfo(
step=step,
is_saving_in_progress=is_saving_in_progress,
reached_preemption=reached_preemption,
)
return self._save_interval_policy.should_save(
current_step_info, previous_steps=previous_step_infos
)

def _get_save_directory(
Expand Down Expand Up @@ -1264,6 +1320,8 @@ def save(

assert self._finalize_thread is None
if is_async_checkpointer(self._checkpointer):
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = True
with self._finalize_thread_lock:
finalize_thread_name = 'save_finalize'
logging.info(
Expand Down Expand Up @@ -1769,10 +1827,8 @@ def wait_until_finished(self):

def is_saving_in_progress(self) -> bool:
"""Returns whether a checkpoint save is in progress."""
with self._finalize_thread_lock:
return (
self._finalize_thread is not None and self._finalize_thread.is_alive()
)
with self._is_saving_in_progress_lock:
return self._is_saving_in_progress

def check_for_errors(self):
"""See superclass documentation."""
Expand Down Expand Up @@ -1852,6 +1908,8 @@ def _finalize(self, step: int, steps_to_remove: List[int]):
threading.current_thread().name,
step,
)
with self._is_saving_in_progress_lock:
self._is_saving_in_progress = False

def close(self):
"""See superclass documentation."""
Expand Down
Loading
Loading