Commit d1390df

Align flame with torchtitan >= 0.2
Update the training stack to the torchtitan 0.2 surface:

- `JobConfig` moves the parallelism knobs into a dedicated `parallelism.*` section and renames several checkpoint / lr_scheduler / experimental fields (`enable_checkpoint` -> `enable`, `lr_min` -> `min_lr_factor`, `last_save_model_weights_only` -> `last_save_model_only`, `custom_model_path` -> `custom_import`, ...).
- Various torchtitan internals (CheckpointManager, build_lr_schedulers, MetricsProcessor, init_distributed) require new attributes on their sub-configs (`last_save_in_hf`, `total_steps`, `comm.mode`, `fault_tolerance.enable`, `parallelism.pipeline_parallel_schedule`, ...).
- `torch.nn.attention.varlen` and new inductor options that torchtitan 0.2 imports/uses at load time are missing on older torch builds.

Changes:

- `flame/config_manager.py`: canonical names only (no BC aliases), new required fields added, and a `parallelism` subconfig mirrored from the existing `training.*` / `experimental.*` flags so torchtitan internals see what they expect.
- `flame/train.py`: simplified custom-import path, cleaned up to match the new APIs (init_distributed takes comm sub-config, ParallelDims builds meshes up front, MetricsProcessor constructed after parallelisms).
- `flame/models/parallelize_fla.py` / `pipeline_fla.py`: updated signatures to the new protocol; PP still unsupported for HF-based models and raises NotImplementedError cleanly.
- `flame/__init__.py`: install best-effort shims (varlen stub, tolerant torch.compile) so torchtitan 0.2 imports cleanly on torch builds that predate `torch.nn.attention.varlen` / new inductor options.
- `flame/models/fla.toml`, `train.sh`, `README.md`: switch to the new flag names; README drops the stale 300-line argparse dump for a concise flag-group overview plus a legacy -> new migration table.

Tests: `pytest tests/test_data.py -q` (25 passed).
1 parent 24dea91 commit d1390df

8 files changed

Lines changed: 475 additions & 717 deletions

README.md

Lines changed: 44 additions & 300 deletions
Large diffs are not rendered by default.

flame/__init__.py

Lines changed: 72 additions & 0 deletions
@@ -1 +1,73 @@
+# -*- coding: utf-8 -*-
+
+"""flame package root.
+
+This module installs small compatibility shims so that torchtitan >= 0.2
+can be imported on top of older torch builds that predate some of the
+symbols / inductor options torchtitan assumes. The shims are best-effort
+and only suppress import-time failures; they do NOT try to silently
+emulate functionality that doesn't exist. If you actually invoke a
+feature that relies on a missing primitive (e.g. varlen attention on
+torch < 2.10), you get a clear runtime error.
+"""
+
+import re
+import sys
+import types
+
+# ---------------------------------------------------------------------------
+# Shim 1: torch.nn.attention.varlen (introduced in torch 2.10). torchtitan
+# imports ``varlen_attn`` at module load, so provide a stub if the real
+# module is absent.
+# ---------------------------------------------------------------------------
+if "torch.nn.attention.varlen" not in sys.modules:
+    try:
+        import torch.nn.attention.varlen  # noqa: F401
+    except ImportError:
+        import torch.nn.attention as _attn_pkg
+
+        _stub = types.ModuleType("torch.nn.attention.varlen")
+
+        def _missing_varlen_attn(*args, **kwargs):
+            raise RuntimeError(
+                "torch.nn.attention.varlen.varlen_attn is not available in this "
+                "torch build. Upgrade to torch >= 2.10 to use varlen attention."
+            )
+
+        _stub.varlen_attn = _missing_varlen_attn
+        sys.modules["torch.nn.attention.varlen"] = _stub
+        setattr(_attn_pkg, "varlen", _stub)
+
+# ---------------------------------------------------------------------------
+# Shim 2: torch.compile option tolerance. torchtitan pins a few inductor
+# options (e.g. ``wrap_inductor_compiled_regions``) that only land in newer
+# torch builds. On older builds, torch.compile raises RuntimeError at call
+# time. Wrap torch.compile so that unknown options are dropped with a
+# warning instead of aborting the whole import chain.
+# ---------------------------------------------------------------------------
+import torch as _torch  # noqa: E402
+
+_orig_compile = _torch.compile
+_UNKNOWN_OPT_RE = re.compile(r"Unexpected optimization option (\S+?)[,\s]")
+
+
+def _tolerant_compile(*args, **kwargs):
+    options = kwargs.get("options")
+    if not options:
+        return _orig_compile(*args, **kwargs)
+    fixed = dict(options)
+    while True:
+        try:
+            kwargs["options"] = fixed
+            return _orig_compile(*args, **kwargs)
+        except RuntimeError as exc:
+            m = _UNKNOWN_OPT_RE.search(str(exc))
+            if m is None or not fixed:
+                raise
+            fixed.pop(m.group(1), None)
+
+
+_tolerant_compile.__wrapped__ = _orig_compile
+_torch.compile = _tolerant_compile
+
 __version__ = "0.1.0"
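
The retry loop in Shim 2 can be exercised without torch. The following is a minimal sketch, not the committed code: `fake_compile` and `KNOWN_OPTIONS` are hypothetical stand-ins for `torch.compile`, and the error text is only assumed to have the shape the shim's regex expects. Two behaviors here are additions of the sketch rather than of the commit: it emits the warning that the Shim 2 comment mentions, and it re-raises when the named option is not among the remaining options, so the loop cannot retry the same call forever.

```python
import re
import warnings

# Same pattern as the shim: captures the option name from the error text.
UNKNOWN_OPT_RE = re.compile(r"Unexpected optimization option (\S+?)[,\s]")

KNOWN_OPTIONS = {"max_autotune"}  # hypothetical set accepted by the stand-in


def fake_compile(fn, options=None):
    """Hypothetical stand-in for torch.compile; rejects unknown options."""
    for name in options or {}:
        if name not in KNOWN_OPTIONS:
            raise RuntimeError(
                f"Unexpected optimization option {name}, known options are ..."
            )
    return fn


def make_tolerant(compile_fn):
    """Wrap compile_fn so that unknown options are dropped one at a time."""

    def tolerant(*args, **kwargs):
        options = kwargs.get("options")
        if not options:
            return compile_fn(*args, **kwargs)
        fixed = dict(options)
        while True:
            try:
                kwargs["options"] = fixed
                return compile_fn(*args, **kwargs)
            except RuntimeError as exc:
                m = UNKNOWN_OPT_RE.search(str(exc))
                # Re-raise unless the error names an option we can remove;
                # this also guarantees that every retry makes progress.
                if m is None or m.group(1) not in fixed:
                    raise
                warnings.warn(f"dropping unsupported compile option {m.group(1)!r}")
                del fixed[m.group(1)]

    tolerant.__wrapped__ = compile_fn
    return tolerant


tolerant_compile = make_tolerant(fake_compile)
```

Calling `tolerant_compile(fn, options={"max_autotune": True, "wrap_inductor_compiled_regions": True})` with this stand-in drops the second option, warns once, and returns `fn`.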

flame/config_manager.py

Lines changed: 146 additions & 28 deletions
@@ -251,14 +251,24 @@ def __init__(self):
             """,
         )
         self.parser.add_argument(
-            "--lr_scheduler.lr_min",
+            "--lr_scheduler.min_lr_factor",
             type=float,
             default=0.0,
             help="""
             Min lr ratio for lr scheduler.
 
-            If provided, the range of decay factor is scaled from 1 to `lr_min`
-            to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
+            If provided, the range of decay factor is scaled from 1 to `min_lr_factor`
+            to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.min_lr_factor`.
+            """,
+        )
+        self.parser.add_argument(
+            "--lr_scheduler.total_steps",
+            type=int,
+            default=None,
+            help="""
+            Total steps for LR schedule calculation. If None, defaults to training.steps.
+            Lets the LR schedule be decoupled from the actual training steps, useful for
+            early stopping or debug-length runs that should follow the full-training curve.
             """,
         )
 
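
To make the interaction between `min_lr_factor` and `total_steps` concrete, here is a rough sketch of a schedule multiplier. It assumes linear warmup followed by cosine decay; `lr_factor` is a hypothetical helper written for illustration and is not torchtitan's `build_lr_schedulers`. The `warmup_steps`, `min_lr_factor`, and `lr` values come from `fla.toml`; the schedule length of 20480 is made up.

```python
import math


def lr_factor(step, warmup_steps, total_steps, min_lr_factor=0.0):
    """Multiplier applied to optimizer.lr at a given step (illustrative only).

    Assumes linear warmup, then a cosine decay whose range is rescaled
    from [1, 0] to [1, min_lr_factor].
    """
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr_factor + (1.0 - min_lr_factor) * cosine


base_lr = 3e-4
# total_steps=20480 is an assumed schedule length, not a value from the commit.
final_lr = base_lr * lr_factor(20480, 1024, 20480, min_lr_factor=0.1)
```

Under these assumptions the learning rate ends at `base_lr * min_lr_factor` rather than at zero. A debug run that stops early but passes the full `total_steps` sees the same multipliers as the full run for the steps it does execute, which is the decoupling the help text describes.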

@@ -502,6 +512,23 @@ def __init__(self):
             action="store_true",
             help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
         )
+        # Torchtitan 0.2 moved most parallelism knobs into a dedicated `parallelism`
+        # section. We still expose them under --training.* and --experimental.* for
+        # backwards compatibility with existing scripts; `_validate_config` mirrors
+        # the values into a `parallelism` subconfig so torchtitan internals can read
+        # them under the new names (e.g. `job_config.parallelism.pipeline_parallel_schedule`).
+        self.parser.add_argument(
+            "--parallelism.pipeline_parallel_schedule",
+            type=str,
+            default=None,
+            help="[torchtitan 0.2] Pipeline parallel schedule. If unset, mirrors --experimental.pipeline_parallel_schedule.",
+        )
+        self.parser.add_argument(
+            "--parallelism.context_parallel_load_balancer",
+            type=str,
+            default="headtail",
+            help="Load balancer type for context parallelism (passed through to torchtitan 0.2).",
+        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_degree",
             type=int,
@@ -595,19 +622,18 @@ def __init__(self):
         # with TorchFT.
         # This option is subject to change and may be deleted in the future.
         self.parser.add_argument(
-            "--experimental.custom_model_path",
+            "--experimental.custom_import",
             type=str,
             default="",
             help="""
-            The --custom_model_path option allows to specify a custom path to a model module
-            that is not natively implemented within TorchTitan.
-            Acceptable values are the file system path to the module (e.g., my_models/model_x)
-            dotted import module (e.g., some_package.model_x).
+            Import a custom model module by dotted import path (e.g. `some_package.model_x`).
+            Use this to register external model definitions that aren't natively implemented
+            within torchtitan / flame.
             """,
         )
         # checkpointing configs
         self.parser.add_argument(
-            "--checkpoint.enable_checkpoint",
+            "--checkpoint.enable",
             action="store_true",
             help="Whether to enable checkpoint",
         )
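
Since `--experimental.custom_import` now accepts only a dotted import path, consuming it can be as small as a guarded `importlib.import_module` call. This is a sketch under that assumption; `import_custom_module` is a hypothetical helper, not the code in `flame/train.py`, and the standard-library `json` module stands in for something like `some_package.model_x`.

```python
import importlib


def import_custom_module(dotted_path: str):
    """Import a module by dotted path; an empty string means nothing to import."""
    if not dotted_path:
        return None
    # Importing runs the module's top-level code, which is where a custom
    # model definition would typically register itself.
    return importlib.import_module(dotted_path)


# `json` is only a placeholder for a real model module.
module = import_custom_module("json")
```

A filesystem path such as `my_models/model_x`, which the old `custom_model_path` help text listed as acceptable, would fail here with `ModuleNotFoundError` and would need to be rewritten as a dotted path.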
@@ -617,7 +643,7 @@ def __init__(self):
             default="checkpoint",
             help="""
             The folder to store the checkpoints.
-            When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
+            When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
             """,
         )
         self.parser.add_argument(
@@ -631,29 +657,57 @@ def __init__(self):
             This feature allows users to load an initial checkpoint from a different folder and
             continue training, saving new checkpoints to the specified folder without affecting
             the existing ones.
-
+
             Note that the path should contain the full path to the checkpoint folder,
             including the step number, if any; for example,
             "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
             """
         )
         self.parser.add_argument(
-            "--checkpoint.initial_load_model_weights_only",
-            dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
+            "--checkpoint.initial_load_model_only",
+            dest='checkpoint.initial_load_model_only', action="store_true", default=True,
             help="""
-            This option specifies if only the model weights should be loaded during the initial
-            checkpoint load. The option is only used when `initial_load_path` is specified, and
-            only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
-            may lead to unexpected behavior if this option is set to True.
+            If True, only the model weights are loaded during the initial checkpoint load.
             If False, the checkpoint at `initial_load_path` is treated as a standard training
-            checkpoint, including optimizer and training states.
-            The default setting for this option is True. Note that you will have to use
-            `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
+            checkpoint, including optimizer and training states. Use
+            `--checkpoint.no_initial_load_model_only` to set to False.
             """
         )
         self.parser.add_argument(
-            "--checkpoint.no_initial_load_model_weights_only",
-            dest='checkpoint.initial_load_model_weights_only', action="store_false",
+            "--checkpoint.no_initial_load_model_only",
+            dest='checkpoint.initial_load_model_only', action="store_false",
+        )
+        self.parser.add_argument(
+            "--checkpoint.initial_load_in_hf",
+            action="store_true",
+            help="Load the initial checkpoint from HF safetensors format.",
+        )
+        self.parser.add_argument(
+            "--checkpoint.initial_load_in_hf_quantized",
+            action="store_true",
+            help="Load initial HF safetensors checkpoint with quantized keys (requires a HF storage reader).",
+        )
+        self.parser.add_argument(
+            "--checkpoint.enable_first_step_checkpoint",
+            action="store_true",
+            help="Save a checkpoint after step 1 (useful to validate checkpointing end-to-end).",
+        )
+        self.parser.add_argument(
+            "--checkpoint.enable_ft_dataloader_checkpoints",
+            dest="checkpoint.enable_ft_dataloader_checkpoints",
+            action="store_true",
+            default=True,
+            help="Snapshot dataloader index in checkpoints (needed for fault-tolerant training).",
+        )
+        self.parser.add_argument(
+            "--checkpoint.no_enable_ft_dataloader_checkpoints",
+            dest="checkpoint.enable_ft_dataloader_checkpoints",
+            action="store_false",
+        )
+        self.parser.add_argument(
+            "--checkpoint.load_only",
+            action="store_true",
+            help="Only load checkpoints; do not save new ones (useful for verification).",
         )
         self.parser.add_argument(
             "--checkpoint.interval",
@@ -662,16 +716,20 @@ def __init__(self):
             help="Checkpointing interval in steps.",
         )
         self.parser.add_argument(
-            "--checkpoint.last_save_model_weights_only",
+            "--checkpoint.last_save_model_only",
             action="store_true",
             help="""
-            When last_save_model_weights_only=True, only model weights will be saved at the end of training,
-            the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
-            after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
-            A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
-            The default value is false.
+            When True, only model weights are saved at the end of training (the last save).
+            With this, checkpoints can be loaded via `torch.load(..., weights_only=True)` after
+            conversion. When False, the full checkpoint is saved (model + optimizer + state),
+            which can be used to resume training. Default is False.
             """,
         )
+        self.parser.add_argument(
+            "--checkpoint.last_save_in_hf",
+            action="store_true",
+            help="Save the last checkpoint as HF safetensors. Requires last_save_model_only=True.",
+        )
         self.parser.add_argument(
             "--checkpoint.export_dtype",
             type=str,
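
The paired `--checkpoint.initial_load_model_only` / `--checkpoint.no_initial_load_model_only` flags (and the `enable_ft_dataloader_checkpoints` pair) rely on two argparse actions sharing one `dest`. A standalone sketch of that pattern, using a bare `argparse.ArgumentParser` rather than flame's config manager; the `parse` helper is hypothetical:

```python
import argparse

parser = argparse.ArgumentParser()
# Both flags write to the same dotted dest; the default is True.
parser.add_argument(
    "--checkpoint.initial_load_model_only",
    dest="checkpoint.initial_load_model_only",
    action="store_true",
    default=True,
)
parser.add_argument(
    "--checkpoint.no_initial_load_model_only",
    dest="checkpoint.initial_load_model_only",
    action="store_false",
)


def parse(argv):
    # Dotted names are not valid attribute names, so read them via vars().
    return vars(parser.parse_args(argv))["checkpoint.initial_load_model_only"]
```

With this setup `parse([])` and `parse(["--checkpoint.initial_load_model_only"])` both yield `True`, so the positive flag is effectively a no-op and only the `no_` variant changes the value.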
@@ -820,6 +878,30 @@ def __init__(self):
             default=20000,
             help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
         )
+        self.parser.add_argument(
+            "--comm.save_traces_folder",
+            type=str,
+            default="comm_traces",
+            help="Flight recorder trace files location.",
+        )
+        self.parser.add_argument(
+            "--comm.save_traces_file_prefix",
+            type=str,
+            default="rank_",
+            help="Flight recorder trace files prefix.",
+        )
+        self.parser.add_argument(
+            "--comm.mode",
+            type=str,
+            default="default",
+            choices=["default", "fake_backend", "local_tensor"],
+            help="""
+            Communication mode for distributed training.
+            - "default": Normal distributed training with real communication.
+            - "fake_backend": Fake comm backend for dry run / config validation without GPU.
+            - "local_tensor": Simulate distributed training in a single process for debugging.
+            """,
+        )
 
         # memory estimation settings
         self.parser.add_argument(
@@ -924,6 +1006,42 @@ def _validate_config(self) -> None:
         assert self.model.config
         assert self.model.tokenizer_path
 
+        # Populate a `parallelism` subconfig mirroring the parallelism knobs that
+        # torchtitan >= 0.2 reads off `job_config.parallelism.*`. We keep flame's
+        # original --training.* / --experimental.* flags (they pre-date torchtitan's
+        # split) and just forward them here into the shape torchtitan expects.
+        parallelism_values = {
+            "pipeline_parallel_schedule": (
+                getattr(self.parallelism, "pipeline_parallel_schedule", None)
+                or getattr(self.experimental, "pipeline_parallel_schedule", "1F1B")
+            ),
+            "context_parallel_load_balancer": getattr(
+                self.parallelism, "context_parallel_load_balancer", "headtail"
+            ),
+            "pipeline_parallel_degree": getattr(self.experimental, "pipeline_parallel_degree", 1),
+            "pipeline_parallel_split_points": getattr(self.experimental, "pipeline_parallel_split_points", []),
+            "pipeline_parallel_microbatches": getattr(self.experimental, "pipeline_parallel_microbatches", None),
+            "pipeline_parallel_schedule_csv": getattr(self.experimental, "pipeline_parallel_schedule_csv", ""),
+            "context_parallel_degree": getattr(self.experimental, "context_parallel_degree", 1),
+            "context_parallel_rotate_method": getattr(self.experimental, "context_parallel_rotate_method", "allgather"),
+            "tensor_parallel_degree": getattr(self.training, "tensor_parallel_degree", 1),
+            "data_parallel_shard_degree": getattr(self.training, "data_parallel_shard_degree", -1),
+            "data_parallel_replicate_degree": getattr(self.training, "data_parallel_replicate_degree", 1),
+            "disable_loss_parallel": getattr(self.training, "disable_loss_parallel", False),
+            "enable_async_tensor_parallel": getattr(self.experimental, "enable_async_tensor_parallel", False),
+            "expert_parallel_degree": 1,
+            "expert_tensor_parallel_degree": 1,
+            "fsdp_reshard_after_forward": getattr(self.training, "fsdp_reshard_after_forward", "default"),
+        }
+        self.parallelism = type("Parallelism", (), parallelism_values)()
+
+        # Ensure `fault_tolerance.enable` / `replica_id` exist — torchtitan's
+        # metrics processor unconditionally reads them.
+        if not hasattr(self.fault_tolerance, "enable"):
+            self.fault_tolerance.enable = False
+        if not hasattr(self.fault_tolerance, "replica_id"):
+            self.fault_tolerance.replica_id = 0
+
     def _get_string_list_argument_names(self) -> list[str]:
         """Get the parser argument names of type `string_list`."""
         string_list_args = [
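
The mirroring step in `_validate_config` can be sketched in isolation. The version below is an illustrative subset, not the committed code: `mirror_parallelism` is a hypothetical function, and it uses `types.SimpleNamespace` in place of the commit's `type("Parallelism", (), values)()` so that the values live on the instance and print readably. The schedule names are example strings only.

```python
from types import SimpleNamespace


def mirror_parallelism(training, experimental, parallelism=None):
    """Build a `parallelism` namespace from the legacy sections (subset only)."""
    # An explicit --parallelism.* value wins; otherwise fall back to the
    # legacy --experimental.* flag, then to the "1F1B" default.
    schedule = (
        getattr(parallelism, "pipeline_parallel_schedule", None)
        or getattr(experimental, "pipeline_parallel_schedule", "1F1B")
    )
    return SimpleNamespace(
        pipeline_parallel_schedule=schedule,
        pipeline_parallel_degree=getattr(experimental, "pipeline_parallel_degree", 1),
        tensor_parallel_degree=getattr(training, "tensor_parallel_degree", 1),
        data_parallel_shard_degree=getattr(training, "data_parallel_shard_degree", -1),
    )


training = SimpleNamespace(tensor_parallel_degree=2)
experimental = SimpleNamespace(pipeline_parallel_schedule="Interleaved1F1B")
parallelism = mirror_parallelism(training, experimental)
```

Here `parallelism.pipeline_parallel_schedule` comes from the legacy `experimental` section because no `parallelism` value was supplied, while fields missing from both sections fall back to the same defaults the diff uses.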

flame/models/fla.toml

Lines changed: 5 additions & 5 deletions
@@ -35,14 +35,14 @@ lr = 3e-4
 [lr_scheduler]
 warmup_steps = 1024
 decay_type = "cosine"
-lr_min = 0.1
+min_lr_factor = 0.1
 
 [checkpoint]
-enable_checkpoint = true
+enable = true
 folder = "checkpoint"
-interval_type = "steps"
 interval = 2048
-model_weights_only = false
+# Save the full checkpoint (not weights-only) so training can resume from here.
+last_save_model_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
@@ -64,4 +64,4 @@ enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 
 [activation_checkpoint]
-mode = "none"
+mode = "none"
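
For configs that still use the legacy names, the renames in this commit can be applied mechanically to an already-parsed config. The sketch below works on plain nested dicts; `RENAMES`, `DROPPED`, and `migrate` are hypothetical names, and treating `model_weights_only` as a rename to `last_save_model_only` is an assumption read off the `fla.toml` diff rather than something the commit message states.

```python
# Renames listed in the commit message or visible in the fla.toml diff.
RENAMES = {
    ("lr_scheduler", "lr_min"): ("lr_scheduler", "min_lr_factor"),
    ("checkpoint", "enable_checkpoint"): ("checkpoint", "enable"),
    ("checkpoint", "model_weights_only"): ("checkpoint", "last_save_model_only"),
    ("checkpoint", "last_save_model_weights_only"): ("checkpoint", "last_save_model_only"),
    ("experimental", "custom_model_path"): ("experimental", "custom_import"),
}
# Keys the fla.toml diff removes without a replacement.
DROPPED = {("checkpoint", "interval_type")}


def migrate(config: dict) -> dict:
    """Return a copy of a parsed config with legacy keys renamed or dropped."""
    out = {section: dict(values) for section, values in config.items()}
    for (section, old), (new_section, new) in RENAMES.items():
        if old in out.get(section, {}):
            out.setdefault(new_section, {})[new] = out[section].pop(old)
    for section, key in DROPPED:
        out.get(section, {}).pop(key, None)
    return out


legacy = {
    "lr_scheduler": {"warmup_steps": 1024, "decay_type": "cosine", "lr_min": 0.1},
    "checkpoint": {"enable_checkpoint": True, "interval_type": "steps", "interval": 2048},
}
migrated = migrate(legacy)
```

The function only renames keys; a `custom_model_path` value that was a filesystem path would still need to be converted to a dotted import path by hand.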
