@@ -251,14 +251,24 @@ def __init__(self):
251251 """ ,
252252 )
253253 self .parser .add_argument (
254- "--lr_scheduler.lr_min " ,
254+ "--lr_scheduler.min_lr_factor " ,
255255 type = float ,
256256 default = 0.0 ,
257257 help = """
258258 Min lr ratio for lr scheduler.
259259
260- If provided, the range of decay factor is scaled from 1 to `lr_min`
261- to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.lr_min`.
260+ If provided, the range of decay factor is scaled from 1 to `min_lr_factor`
261+ to ensure the learning rate does not drop below `optimizer.lr * lr_scheduler.min_lr_factor`.
262+ """ ,
263+ )
264+ self .parser .add_argument (
265+ "--lr_scheduler.total_steps" ,
266+ type = int ,
267+ default = None ,
268+ help = """
269+ Total steps for LR schedule calculation. If None, defaults to training.steps.
270+ Lets the LR schedule be decoupled from the actual training steps, useful for
271+ early stopping or debug-length runs that should follow the full-training curve.
262272 """ ,
263273 )
264274
@@ -502,6 +512,23 @@ def __init__(self):
502512 action = "store_true" ,
503513 help = "Whether to apply async tensor parallel (currently only effective when compile is enabled)" ,
504514 )
515+ # Torchtitan 0.2 moved most parallelism knobs into a dedicated `parallelism`
516+ # section. We still expose them under --training.* and --experimental.* for
517+ # backwards compatibility with existing scripts; `_validate_config` mirrors
518+ # the values into a `parallelism` subconfig so torchtitan internals can read
519+ # them under the new names (e.g. `job_config.parallelism.pipeline_parallel_schedule`).
520+ self .parser .add_argument (
521+ "--parallelism.pipeline_parallel_schedule" ,
522+ type = str ,
523+ default = None ,
524+ help = "[torchtitan 0.2] Pipeline parallel schedule. If unset, mirrors --experimental.pipeline_parallel_schedule." ,
525+ )
526+ self .parser .add_argument (
527+ "--parallelism.context_parallel_load_balancer" ,
528+ type = str ,
529+ default = "headtail" ,
530+ help = "Load balancer type for context parallelism (passed through to torchtitan 0.2)." ,
531+ )
505532 self .parser .add_argument (
506533 "--experimental.pipeline_parallel_degree" ,
507534 type = int ,
@@ -595,19 +622,18 @@ def __init__(self):
595622 # with TorchFT.
596623 # This option is subject to change and may be deleted in the future.
597624 self .parser .add_argument (
598- "--experimental.custom_model_path " ,
625+ "--experimental.custom_import " ,
599626 type = str ,
600627 default = "" ,
601628 help = """
602- The --custom_model_path option allows to specify a custom path to a model module
603- that is not natively implemented within TorchTitan.
604- Acceptable values are the file system path to the module (e.g., my_models/model_x)
605- dotted import module (e.g., some_package.model_x).
629+ Import a custom model module by dotted import path (e.g. `some_package.model_x`).
630+ Use this to register external model definitions that aren't natively implemented
631+ within torchtitan / flame.
606632 """ ,
607633 )
608634 # checkpointing configs
609635 self .parser .add_argument (
610- "--checkpoint.enable_checkpoint " ,
636+ "--checkpoint.enable " ,
611637 action = "store_true" ,
612638 help = "Whether to enable checkpoint" ,
613639 )
@@ -617,7 +643,7 @@ def __init__(self):
617643 default = "checkpoint" ,
618644 help = """
619645 The folder to store the checkpoints.
620- When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
646+ When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
621647 """ ,
622648 )
623649 self .parser .add_argument (
@@ -631,29 +657,57 @@ def __init__(self):
631657 This feature allows users to load an initial checkpoint from a different folder and
632658 continue training, saving new checkpoints to the specified folder without affecting
633659 the existing ones.
634-
660+
635661 Note that the path should contain the full path to the checkpoint folder,
636662 including the step number, if any; for example,
637663 "//pre_train/checkpoints/llama3/llama3_8b/step_10000".
638664 """
639665 )
640666 self .parser .add_argument (
641- "--checkpoint.initial_load_model_weights_only " ,
642- dest = 'checkpoint.initial_load_model_weights_only ' , action = "store_true" , default = True ,
667+ "--checkpoint.initial_load_model_only " ,
668+ dest = 'checkpoint.initial_load_model_only ' , action = "store_true" , default = True ,
643669 help = """
644- This option specifies if only the model weights should be loaded during the initial
645- checkpoint load. The option is only used when `initial_load_path` is specified, and
646- only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
647- may lead to unexpected behavior if this option is set to True.
670+ If True, only the model weights are loaded during the initial checkpoint load.
648671 If False, the checkpoint at `initial_load_path` is treated as a standard training
649- checkpoint, including optimizer and training states.
650- The default setting for this option is True. Note that you will have to use
651- `--checkpoint.no_initial_load_model_weights_only` to override the default setting.
672+ checkpoint, including optimizer and training states. Use
673+ `--checkpoint.no_initial_load_model_only` to set to False.
652674 """
653675 )
654676 self .parser .add_argument (
655- "--checkpoint.no_initial_load_model_weights_only" ,
656- dest = 'checkpoint.initial_load_model_weights_only' , action = "store_false" ,
677+ "--checkpoint.no_initial_load_model_only" ,
678+ dest = 'checkpoint.initial_load_model_only' , action = "store_false" ,
679+ )
680+ self .parser .add_argument (
681+ "--checkpoint.initial_load_in_hf" ,
682+ action = "store_true" ,
683+ help = "Load the initial checkpoint from HF safetensors format." ,
684+ )
685+ self .parser .add_argument (
686+ "--checkpoint.initial_load_in_hf_quantized" ,
687+ action = "store_true" ,
688+ help = "Load initial HF safetensors checkpoint with quantized keys (requires a HF storage reader)." ,
689+ )
690+ self .parser .add_argument (
691+ "--checkpoint.enable_first_step_checkpoint" ,
692+ action = "store_true" ,
693+ help = "Save a checkpoint after step 1 (useful to validate checkpointing end-to-end)." ,
694+ )
695+ self .parser .add_argument (
696+ "--checkpoint.enable_ft_dataloader_checkpoints" ,
697+ dest = "checkpoint.enable_ft_dataloader_checkpoints" ,
698+ action = "store_true" ,
699+ default = True ,
700+ help = "Snapshot dataloader index in checkpoints (needed for fault-tolerant training)." ,
701+ )
702+ self .parser .add_argument (
703+ "--checkpoint.no_enable_ft_dataloader_checkpoints" ,
704+ dest = "checkpoint.enable_ft_dataloader_checkpoints" ,
705+ action = "store_false" ,
706+ )
707+ self .parser .add_argument (
708+ "--checkpoint.load_only" ,
709+ action = "store_true" ,
710+ help = "Only load checkpoints; do not save new ones (useful for verification)." ,
657711 )
658712 self .parser .add_argument (
659713 "--checkpoint.interval" ,
@@ -662,16 +716,20 @@ def __init__(self):
662716 help = "Checkpointing interval in steps." ,
663717 )
664718 self .parser .add_argument (
665- "--checkpoint.last_save_model_weights_only " ,
719+ "--checkpoint.last_save_model_only " ,
666720 action = "store_true" ,
667721 help = """
668- When last_save_model_weights_only=True, only model weights will be saved at the end of training,
669- the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
670- after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
671- A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
672- The default value is false.
722+ When True, only model weights are saved at the end of training (the last save).
723+ With this, checkpoints can be loaded via `torch.load(..., weights_only=True)` after
724+ conversion. When False, the full checkpoint is saved (model + optimizer + state),
725+ which can be used to resume training. Default is False.
673726 """ ,
674727 )
728+ self .parser .add_argument (
729+ "--checkpoint.last_save_in_hf" ,
730+ action = "store_true" ,
731+ help = "Save the last checkpoint as HF safetensors. Requires last_save_model_only=True." ,
732+ )
675733 self .parser .add_argument (
676734 "--checkpoint.export_dtype" ,
677735 type = str ,
@@ -820,6 +878,30 @@ def __init__(self):
820878 default = 20000 ,
821879 help = "Flight recorder ring buffer size, >0 means recording by default, 0 means disabled" ,
822880 )
881+ self .parser .add_argument (
882+ "--comm.save_traces_folder" ,
883+ type = str ,
884+ default = "comm_traces" ,
885+ help = "Flight recorder trace files location." ,
886+ )
887+ self .parser .add_argument (
888+ "--comm.save_traces_file_prefix" ,
889+ type = str ,
890+ default = "rank_" ,
891+ help = "Flight recorder trace files prefix." ,
892+ )
893+ self .parser .add_argument (
894+ "--comm.mode" ,
895+ type = str ,
896+ default = "default" ,
897+ choices = ["default" , "fake_backend" , "local_tensor" ],
898+ help = """
899+ Communication mode for distributed training.
900+ - "default": Normal distributed training with real communication.
901+ - "fake_backend": Fake comm backend for dry run / config validation without GPU.
902+ - "local_tensor": Simulate distributed training in a single process for debugging.
903+ """ ,
904+ )
823905
824906 # memory estimation settings
825907 self .parser .add_argument (
@@ -924,6 +1006,42 @@ def _validate_config(self) -> None:
9241006 assert self .model .config
9251007 assert self .model .tokenizer_path
9261008
1009+ # Populate a `parallelism` subconfig mirroring the parallelism knobs that
1010+ # torchtitan >= 0.2 reads off `job_config.parallelism.*`. We keep flame's
1011+ # original --training.* / --experimental.* flags (they pre-date torchtitan's
1012+ # split) and just forward them here into the shape torchtitan expects.
1013+ parallelism_values = {
1014+ "pipeline_parallel_schedule" : (
1015+ getattr (self .parallelism , "pipeline_parallel_schedule" , None )
1016+ or getattr (self .experimental , "pipeline_parallel_schedule" , "1F1B" )
1017+ ),
1018+ "context_parallel_load_balancer" : getattr (
1019+ self .parallelism , "context_parallel_load_balancer" , "headtail"
1020+ ),
1021+ "pipeline_parallel_degree" : getattr (self .experimental , "pipeline_parallel_degree" , 1 ),
1022+ "pipeline_parallel_split_points" : getattr (self .experimental , "pipeline_parallel_split_points" , []),
1023+ "pipeline_parallel_microbatches" : getattr (self .experimental , "pipeline_parallel_microbatches" , None ),
1024+ "pipeline_parallel_schedule_csv" : getattr (self .experimental , "pipeline_parallel_schedule_csv" , "" ),
1025+ "context_parallel_degree" : getattr (self .experimental , "context_parallel_degree" , 1 ),
1026+ "context_parallel_rotate_method" : getattr (self .experimental , "context_parallel_rotate_method" , "allgather" ),
1027+ "tensor_parallel_degree" : getattr (self .training , "tensor_parallel_degree" , 1 ),
1028+ "data_parallel_shard_degree" : getattr (self .training , "data_parallel_shard_degree" , - 1 ),
1029+ "data_parallel_replicate_degree" : getattr (self .training , "data_parallel_replicate_degree" , 1 ),
1030+ "disable_loss_parallel" : getattr (self .training , "disable_loss_parallel" , False ),
1031+ "enable_async_tensor_parallel" : getattr (self .experimental , "enable_async_tensor_parallel" , False ),
1032+ "expert_parallel_degree" : 1 ,
1033+ "expert_tensor_parallel_degree" : 1 ,
1034+ "fsdp_reshard_after_forward" : getattr (self .training , "fsdp_reshard_after_forward" , "default" ),
1035+ }
1036+ self .parallelism = type ("Parallelism" , (), parallelism_values )()
1037+
1038+ # Ensure `fault_tolerance.enable` / `replica_id` exist — torchtitan's
1039+ # metrics processor unconditionally reads them.
1040+ if not hasattr (self .fault_tolerance , "enable" ):
1041+ self .fault_tolerance .enable = False
1042+ if not hasattr (self .fault_tolerance , "replica_id" ):
1043+ self .fault_tolerance .replica_id = 0
1044+
9271045 def _get_string_list_argument_names (self ) -> list [str ]:
9281046 """Get the parser argument names of type `string_list`."""
9291047 string_list_args = [
0 commit comments