57 changes: 57 additions & 0 deletions examples/smollm2/conf/135m.yaml
@@ -0,0 +1,57 @@
defaults:
- _self_
- train: 135m

experiment:
exp_name: SmolLM2-135m
seed: 42
save_steps: 512
load: null
exp_dir: checkpoints/${experiment.exp_name}
ckpt_format: torch
task:
type: train
backend: megatron
entrypoint: flagscale/train/train_gpt.py
runner:
per_node_task: false
no_shared_fs: false
rdzv_backend: static
# nnodes: 1
nproc_per_node: 1
cmds:
before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate flagscale-train
Contributor review comment (severity: medium):
The path /root/miniconda3/bin/activate is hardcoded. This makes the configuration less portable and might fail in environments where miniconda is installed elsewhere. Consider using an environment variable or documenting this as a prerequisite with a clear setup instruction.

envs:
NVTE_DEBUG: 1
NVTE_DEBUG_LEVEL: 2
# CUDNN_LOGERR_DBG: 1
CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
CUDA_DEVICE_MAX_CONNECTIONS: 1
# NCCL_IB_DISABLE: 0
NCCL_IB_CUDA_SUPPORT: 1
NCCL_IB_GID_INDEX: 0
#TORCH_NCCL_USE_COMM_NONBLOCKING: 1
OMP_NUM_THREADS: 4
ENABLE_FLASH_ATTENTION_WITH_IXDNN: 0
NCCL_NET_PLUGIN: none
#NCCL_SHM_DISABLE=1
NCCL_ALGO: Ring
NCCL_P2P_NET_CHUNKSIZE: 1048576
NCCL_CHUNK_SIZE: 1048576
NCCL_BUFFSIZE: 8388608
NCCL_MAX_NCHANNELS: 4
NCCL_MIN_NCHANNELS: 4
NCCL_MAX_P2P_NCHANNELS: 1
NCCL_PROTO: Simple
NCCL_NET_SHARED_BUFFERS: 0
NCCL_P2P_LL_THRESHOLD: 0
IXCCL_MIX_NV: 1
IXCCL_FUSED_ENABLE: 0
NCCL_IB_DISABLE: 0
Contributor review comment (severity: medium):
The configuration key NCCL_IB_DISABLE is defined here, but it's also present (though commented out) on line 30. This duplication can be confusing. Please remove the redundant entry to improve clarity.

NCCL_IB_HCA: mlx5_2,mlx5_5

action: run

hydra:
run:
dir: ${experiment.exp_dir}/hydra
92 changes: 92 additions & 0 deletions examples/smollm2/conf/train/135m.yaml
@@ -0,0 +1,92 @@
system:
reset_position_ids: True
reset_attention_mask: True
logging_level: 10
distributed_timeout_minutes: 60
no_shared_fs: ${experiment.runner.no_shared_fs}
num_workers: 4
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
context_parallel_size: 1
disable_bias_linear: true
# qk_layernorm: true
sequence_parallel: true
use_distributed_optimizer: true
overlap_grad_reduce: true
overlap_param_gather: true
finetune: false
precision:
bf16: true
attention_softmax_in_fp32: true
accumulate_allreduce_grads_in_fp32: true
logging:
log_interval: 1
tensorboard_log_interval: 1
wandb_project: ${experiment.exp_name}
wandb_exp_name: ${experiment.exp_name}
log_timers_to_tensorboard: true
log_validation_ppl_to_tensorboard: true
log_throughput: true
log_params_norm: true
log_num_zeros_in_grad: true
log_memory_to_tensorboard: true
checkpoint:
save_interval: ${experiment.save_steps}
load: ${experiment.load}
ckpt_format: ${experiment.ckpt_format}

model:
transformer_impl: transformer_engine
num_layers: 30
hidden_size: 576
ffn_hidden_size: 1536
kv_channels: 64
group_query_attention: True
Contributor review comment on lines +2 to +44 (severity: medium):
There's an inconsistency in the representation of boolean values in this YAML file. Some are capitalized (True on lines 2, 3, 44), while others are lowercase (true/false). For consistency and to follow common YAML conventions, it's best to use lowercase for all boolean values.

num_attention_heads: 9
num_query_groups: 3 # num_key_value_heads
seq_length: 4096
max_position_embeddings: 4096
norm_epsilon: 1e-6
use_rotary_position_embeddings: true
rotary_base: 100000
swiglu: true
normalization: RMSNorm
init_method_std: 6e-3
attention_dropout: 0.0
hidden_dropout: 0.0
clip_grad: 1.0
position_embedding_type: rope
untie_embeddings_and_output_weights: false
no_position_embedding: true
no_rope_fusion: true

# training
seed: ${experiment.seed}
micro_batch_size: 1
global_batch_size: 1
eval_iters: 0
train_samples: 16777216 #69B tokens

optimizer:
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
lr_scheduler:
lr: 3.0e-4
min_lr: 3.0e-5
lr_warmup_samples: 1024000
lr_decay_style: cosine


data:
data_path:
- 1.0
- MegaScience_text_document

split: 1
no_mmap_bin_files: true
tokenizer:
tokenizer_type: SmolLM2TokenizerFS
tokenizer_path: /share/projset/ldwang/models/HuggingFaceTB/SmolLM2-135M
Contributor review comment (severity: high):
The tokenizer_path is a hardcoded absolute path. This will prevent the training script from running on any machine other than the one where this path is valid. Please replace it with a placeholder or a relative path to ensure the example is reproducible.

vocab_size: 49152
make_vocab_size_divisible_by: 64
@@ -1,5 +1,5 @@
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
-index 5cf222cc..35ad2b85 100644
+index 5cf222ccc..eaa16fc8b 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -11,9 +11,10 @@ from pathlib import Path
@@ -24,7 +24,7 @@ index 5cf222cc..35ad2b85 100644
elif args.tokenizer_type == 'NullTokenizer':
assert args.vocab_size is not None
tokenizer = _NullTokenizer(args.vocab_size)
-@@ -102,6 +106,31 @@ def build_tokenizer(args, **kwargs):
+@@ -102,6 +106,34 @@ def build_tokenizer(args, **kwargs):
elif args.tokenizer_type == 'NullMultimodalTokenizer':
assert args.vocab_size is not None
tokenizer = _NullMultimodalTokenizer(args.vocab_size)
@@ -43,6 +43,9 @@ index 5cf222cc..35ad2b85 100644
+ elif args.tokenizer_type == "QwenTokenizerFS":
+ assert args.tokenizer_path is not None
+ tokenizer = _QwenTokenizerFS(args.tokenizer_path)
+ elif args.tokenizer_type == "SmolLM2TokenizerFS":
+ assert args.tokenizer_path is not None
+ tokenizer = _SmolLM2TokenizerFS(args.tokenizer_path)
+ elif args.tokenizer_type == "HFTokenizersTokenizerFS":
+ assert args.tokenizer_path is not None
+ tokenizer = _HFTokenizersTokenizerFS(args.tokenizer_path)
@@ -56,7 +59,7 @@ index 5cf222cc..35ad2b85 100644
else:
raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(args.tokenizer_type))

-@@ -596,6 +625,16 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
+@@ -596,6 +628,16 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
t = t + [self.eos_id]
return t

@@ -73,7 +76,7 @@ index 5cf222cc..35ad2b85 100644
def detokenize(self, ids):
return self.tokenizer.decode_ids(ids)

-@@ -909,3 +948,276 @@ class _NullMultimodalTokenizer(MegatronTokenizer):
+@@ -909,3 +951,286 @@ class _NullMultimodalTokenizer(MegatronTokenizer):
@property
def additional_special_tokens_ids(self):
return None
@@ -196,6 +199,16 @@ index 5cf222cc..35ad2b85 100644
+ self.pad_id = self.tokenizer.encode('<|endoftext|>')[0]
+
+
+class _SmolLM2TokenizerFS(_HFTokenizerFS):
+ """Adapted SmolLM2 tokenizer."""
+
+ def __init__(self, tokenizer_path):
+ super().__init__(tokenizer_path)
+ self.eod_id = self.tokenizer.encode('<|endoftext|>')[0]
+ self.cls_id = self.tokenizer.encode('<|endoftext|>')[0]
+ self.pad_id = self.tokenizer.encode('<|endoftext|>')[0]
Contributor review comment on lines +207 to +209 (severity: medium):
The token <|endoftext|> is encoded three separate times to set eod_id, cls_id, and pad_id. This is inefficient. It's better to encode the token once and reuse the resulting ID.

        endoftext_id = self.tokenizer.encode('<|endoftext|>')[0]
        self.eod_id = endoftext_id
        self.cls_id = endoftext_id
        self.pad_id = endoftext_id

+
+
+class _HFTokenizersTokenizerFS(MegatronTokenizer):
+ """Tokenizer from HuggingFace Tokenizers."""
+
2 changes: 1 addition & 1 deletion tools/checkpoint/convert.py
@@ -39,7 +39,7 @@ def main():
default=[],
nargs="+",
required=True,
choices=["aquila", "mistral", "mixtral", "llama", "deepseek_v3", "qwen3"],
choices=["aquila", "mistral", "mixtral", "llama", "deepseek_v3", "qwen3", "smollm2"],
help="Type of the model.",
)
parser.add_argument(
94 changes: 94 additions & 0 deletions tools/checkpoint/smollm2/args.py
@@ -0,0 +1,94 @@
import json
import os


def load_args_hf2mg(args):

# Read llama args.
llama_args_path = os.path.join(args.load, "config.json")
with open(llama_args_path) as f:
llama_args = json.load(f)
Contributor review comment on lines +8 to +10 (severity: medium):
The variable names llama_args_path and llama_args are misleading as this script is for the smollm2 model. This appears to be a copy-paste from the llama conversion script. Please rename them to something more generic like hf_config_path and hf_config to avoid confusion.
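
A minimal sketch of the suggested rename, assuming only the variable names change (hf_config_path and hf_config are the hypothetical replacements; only the first few assignments are shown):

    # Read the Hugging Face model config under a model-agnostic name.
    hf_config_path = os.path.join(args.load, "config.json")
    with open(hf_config_path) as f:
        hf_config = json.load(f)

    # Downstream assignments then read from hf_config, e.g.:
    args.hidden_size = hf_config["hidden_size"]
    args.num_layers = hf_config["num_hidden_layers"]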


# Update Megatron args.
args.attention_dropout = llama_args["attention_dropout"]
args.hidden_dropout = llama_args["attention_dropout"]
args.hidden_size = llama_args["hidden_size"]
args.swiglu = llama_args["hidden_act"] == "silu"
args.init_method_std = llama_args["initializer_range"]
args.ffn_hidden_size = llama_args["intermediate_size"]
args.max_position_embeddings = llama_args["max_position_embeddings"]
args.model_type = llama_args["model_type"]
args.num_attention_heads = llama_args["num_attention_heads"]
args.num_layers = llama_args["num_hidden_layers"]
args.num_query_groups = llama_args["num_key_value_heads"]
args.norm_epsilon = llama_args["rms_norm_eps"]
args.rotary_seq_len_interpolation_factor = (
None if llama_args["rope_scaling"] == "null" else llama_args["rope_scaling"]
)
Contributor review comment on lines +25 to +27 (severity: high):
The check llama_args["rope_scaling"] == "null" is likely incorrect. In JSON, null is a keyword, not a string. When loaded by Python's json library, it becomes None. The check should be llama_args["rope_scaling"] is None.

Suggested change:
-args.rotary_seq_len_interpolation_factor = (
-None if llama_args["rope_scaling"] == "null" else llama_args["rope_scaling"]
-)
+args.rotary_seq_len_interpolation_factor = (
+None if llama_args.get("rope_scaling") is None else llama_args["rope_scaling"]
+)

args.rotary_base = llama_args["rope_theta"]
args.untie_embeddings_and_output_weights = not llama_args["tie_word_embeddings"]
args.bf16 = llama_args["dtype"] == "bfloat16"
args.fp16 = llama_args["dtype"] == "float16"
Contributor review comment on lines +30 to +31 (severity: high):
The Hugging Face config.json for Llama-like models typically uses the key torch_dtype, not dtype. Using dtype here will likely result in a KeyError. This seems to be a copy-paste error. Please verify the key in the smollm2 model's config.json and use the correct one, which is probably torch_dtype.

Suggested change:
-args.bf16 = llama_args["dtype"] == "bfloat16"
-args.fp16 = llama_args["dtype"] == "float16"
+args.bf16 = llama_args["torch_dtype"] == "bfloat16"
+args.fp16 = llama_args["torch_dtype"] == "float16"

args.vocab_size = llama_args["vocab_size"]
args.padded_vocab_size = llama_args["vocab_size"]

args.seq_length = 2048
args.global_batch_size = 1024
args.iteration = 1 # '0', 'release' don't work
args.add_position_embedding = False
args.group_query_attention = True
args.normalization = "RMSNorm"
args.use_rotary_position_embeddings = True
args.add_bias_linear = False
args.add_qkv_bias = False
args.make_vocab_size_divisible_by = 64
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
args.norm_has_bias = False
args.tokenizer_type = "Llama3TokenizerFS"
Contributor review comment (severity: critical):
The tokenizer_type is hardcoded to Llama3TokenizerFS. Since this script is for smollm2, it should be set to SmolLM2TokenizerFS. This is a critical bug that will cause the checkpoint conversion to fail or produce an incorrect result.

Suggested change:
-args.tokenizer_type = "Llama3TokenizerFS"
+args.tokenizer_type = "SmolLM2TokenizerFS"



def save_args_mg2hf(args):
from transformers import LlamaConfig
Contributor review comment (severity: high):
The script imports and uses LlamaConfig. This is misleading and potentially incorrect for the smollm2 model. If smollm2 has its own config class, that should be used. If not, it's better to use a more generic AutoConfig to avoid issues if the architectures are not perfectly compatible.


config = LlamaConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
intermediate_size=args.ffn_hidden_size,
num_hidden_layers=args.encoder_num_layers,
num_attention_heads=args.num_attention_heads,
num_key_value_heads=args.num_query_groups,
hidden_act="silu" if args.swiglu else False,
max_position_embeddings=args.max_position_embeddings,
initializer_range=args.init_method_std,
rms_norm_eps=args.norm_epsilon,
use_cache=True,
tie_word_embeddings=not args.untie_embeddings_and_output_weights,
rope_theta=args.rotary_base,
rope_scaling=args.rotary_seq_len_interpolation_factor,
attention_bias=args.add_qkv_bias,
attention_dropout=args.attention_dropout,
torch_dtype=args.params_dtype,
bias_dropout_fusion=args.bias_dropout_fusion,
end_weight_decay=args.end_weight_decay,
global_batch_size=args.global_batch_size,
hidden_dropout=args.hidden_dropout,
lr=args.lr,
lr_decay_style=args.lr_decay_style,
make_vocab_size_divisible_by=args.make_vocab_size_divisible_by,
masked_softmax_fusion=args.masked_softmax_fusion,
min_lr=args.min_lr,
norm_init_weight=args.norm_init_weight,
perform_initialization=args.perform_initialization,
reset_attention_mask=args.reset_attention_mask,
reset_position_ids=args.reset_position_ids,
rotary_base=args.rotary_base,
seed=args.seed,
split=args.split,
start_weight_decay=args.start_weight_decay,
use_flash_attn=args.use_flash_attn,
weight_decay_incr_style=args.weight_decay_incr_style,
)
config.save_pretrained(args.save)

return config
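
Following up on the comment above about LlamaConfig, a minimal sketch of a more generic export path. It assumes args.model_type holds the value recorded from config.json by load_args_hf2mg (e.g. "llama" for SmolLM2) and uses transformers' AutoConfig.for_model to resolve the config class; the Megatron-only training fields passed in the original call (lr, seed, split, etc.) are omitted, and the sketch is untested:

    from transformers import AutoConfig

    def save_args_mg2hf(args):
        # Resolve the config class registered for this model_type instead of
        # hardcoding LlamaConfig.
        config = AutoConfig.for_model(
            args.model_type,
            vocab_size=args.vocab_size,
            hidden_size=args.hidden_size,
            intermediate_size=args.ffn_hidden_size,
            num_hidden_layers=args.encoder_num_layers,
            num_attention_heads=args.num_attention_heads,
            num_key_value_heads=args.num_query_groups,
            max_position_embeddings=args.max_position_embeddings,
            initializer_range=args.init_method_std,
            rms_norm_eps=args.norm_epsilon,
            rope_theta=args.rotary_base,
            tie_word_embeddings=not args.untie_embeddings_and_output_weights,
            attention_dropout=args.attention_dropout,
        )
        config.save_pretrained(args.save)
        return config

AutoConfig.for_model looks up the configuration class registered for the given model_type, so this export path would not need to change if SmolLM2 ever received a dedicated config class in transformers.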