Save and quit on SIGINT and SIGTERM #260

Merged: 2 commits, May 9, 2025
fast_llm/engine/training/trainer.py: 15 additions, 7 deletions

@@ -9,7 +9,7 @@
 import torch
 
 from fast_llm.config import Configurable
-from fast_llm.core.distributed import safe_barrier
+from fast_llm.core.distributed import allreduce_scalar, safe_barrier
 from fast_llm.data.data.abstract import Data
 from fast_llm.data.dataset.config import SamplingParameters
 from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank
@@ -23,7 +23,7 @@
 from fast_llm.engine.training.config import TrainerConfig, TrainingCheckpointBaseConfig, TrainingCheckpointConfig
 from fast_llm.engine.training.wandb import Wandb
 from fast_llm.logging import format_metrics, get_memory_usage_mib, log_memory_usage
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, Interrupter
 
 logger = logging.getLogger(__name__)
 
@@ -214,6 +214,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
             distributed_config=self._config.model.distributed, start_step=self._completed_steps
         )
 
+        interrupter = Interrupter(self._config.training.checkpoint.enabled())
         train_iterator = self._get_data_iterator(
             PhaseType.training.value,
             self._completed_steps,
@@ -231,7 +232,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
         start_iteration = self._completed_steps
         last_iteration = start_iteration
         stop = False
-        with profiler:
+        with profiler, interrupter:
             while not stop:
                 # Iteration starts at 1, so we increment at the beginning.
                 self._completed_steps += 1
@@ -317,8 +318,7 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 profiler.step()
 
                 done = self._completed_steps >= self._config.training.train_iters
-                # TODO: Signal-based stop.
-                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
+
                 # Evaluation
                 # TODO: Adjust valid iterator length.
                 if PhaseType.validation in self._samples_per_split and (
@@ -366,11 +366,19 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:
                 if is_main_rank() and metrics:
                     self._wandb.log_metrics(self._completed_steps, metrics)
 
-                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
-                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+                stop = done or self._config.training.shutdown.enabled(self._completed_steps)
+
                 if self._config.training.export.enabled(None if done else self._completed_steps):
                     self._save_checkpoint(self._config.training.export, metrics)
 
+                if interrupter.enabled:
+                    stop = stop or allreduce_scalar(
+                        interrupter.interrupted, torch.int32, self._distributed.world_group
+                    )
+
+                if self._config.training.checkpoint.enabled(None if stop else self._completed_steps):
+                    self._save_checkpoint(self._config.training.checkpoint, metrics)
+
             # The profiler calls the trace_fn at the end and this could lead to
             profiler.step()
         return done, metrics
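The core of the change above is that a signal may be delivered to only one rank, so each rank's local interrupt flag is all-reduced before the stop decision; that way every rank agrees to stop and the checkpoint save, which is a collective operation, runs consistently everywhere. Below is a minimal sketch of that agreement step using plain torch.distributed instead of Fast-LLM's internal allreduce_scalar helper; the function name any_rank_interrupted and the CPU-tensor assumption are illustrative, not part of the PR.

import torch
import torch.distributed as dist


def any_rank_interrupted(local_interrupted: bool, group=None) -> bool:
    """Return True on every rank if at least one rank observed an interrupt."""
    # Encode the local flag as a tensor so it can take part in the collective.
    # With the NCCL backend this tensor would need to live on the rank's GPU;
    # a CPU tensor (e.g. with the gloo backend) is assumed here.
    flag = torch.tensor(int(local_interrupted), dtype=torch.int32)
    # MAX acts as a logical OR across ranks: stop if any rank was interrupted.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())

In the trainer this result feeds the same stop flag that the scheduled-shutdown check sets, so the checkpoint-on-stop logic does not need to know why training is stopping.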
fast_llm/utils.py: 32 additions, 0 deletions

@@ -1,6 +1,7 @@
 import itertools
 import logging
 import math
+import signal
 import typing
 from typing import Callable
 
@@ -336,3 +337,34 @@ def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple
 def check_equal_nested(config_a, config_b):
     if errors := compare_nested(config_a, config_b):
         raise ValueError("\n".join(errors))
+
+
+class Interrupter:
+    def __init__(self, enabled: bool = True, signals: typing.Sequence[int] = (signal.SIGINT, signal.SIGTERM)):
+        self._enabled = enabled
+        self._signals = signals
+
+    def __enter__(self):
+        self._interrupted = False
+        self._old_signals = (
+            {signum: signal.signal(signum, self._handle_signal) for signum in self._signals} if self._enabled else {}
+        )
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for signum, handler in self._old_signals.items():
+            signal.signal(signum, handler)
+
+    def _handle_signal(self, signum, frame):
+        logger.info(f"Interrupt signal {signal.Signals(signum).name} received.")
+        if self._interrupted:
+            # Raise for a repeated signal, ex. if a user really wants to ctrl-C.
+            self._old_signals[signum](signum, frame)
+        self._interrupted = True
+
+    @property
+    def enabled(self) -> bool:
+        return self._enabled
+
+    @property
+    def interrupted(self):
+        return self._interrupted
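Taken on its own, the Interrupter added above is a small context manager: it installs handlers for SIGINT and SIGTERM on entry, records that a signal arrived instead of letting it kill the process, and restores the previous handlers on exit. A minimal single-process usage sketch follows (no cross-rank reduction needed here); run_one_step, save_state and the step count are placeholders, not Fast-LLM APIs.

import time

from fast_llm.utils import Interrupter


def run_one_step(step: int) -> None:
    time.sleep(0.5)  # stand-in for one unit of work


def save_state(step: int) -> None:
    print(f"saving state at step {step}")


interrupter = Interrupter()
with interrupter:
    for step in range(1000):
        run_one_step(step)
        # The handler only records the signal; the loop exits cleanly at the
        # next step boundary instead of being killed mid-step.
        if interrupter.interrupted:
            save_state(step)
            break
# The previous signal handlers are restored when the with-block exits.

A second Ctrl-C falls through to the previously installed SIGINT handler, matching the "repeated signal" comment in the class.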
tests/test_mtp.py: 1 addition, 0 deletions

@@ -129,6 +129,7 @@ def test_transformer_mtp(config_dict: dict[str, typing.Any]):
     loss.backward()
 
 
+@pytest.mark.skip(reason="Too slow")
 @requires_cuda
 @pytest.mark.skipif(not run_hybrid_test, reason="No CUDA available or Mamba not installed")
 @pytest.mark.parametrize(