diff --git a/examples/smollm2/conf/135m.yaml b/examples/smollm2/conf/135m.yaml
new file mode 100644
index 000000000..43638726d
--- /dev/null
+++ b/examples/smollm2/conf/135m.yaml
@@ -0,0 +1,57 @@
+defaults:
+  - _self_
+  - train: 135m
+
+experiment:
+  exp_name: SmolLM2-135m
+  seed: 42
+  save_steps: 512
+  load: null
+  exp_dir: checkpoints/${experiment.exp_name}
+  ckpt_format: torch
+  task:
+    type: train
+    backend: megatron
+    entrypoint: flagscale/train/train_gpt.py
+  runner:
+    per_node_task: false
+    no_shared_fs: false
+    rdzv_backend: static
+    # nnodes: 1
+    nproc_per_node: 1
+  cmds:
+    before_start: ulimit -n 1048576 && source /root/miniconda3/bin/activate flagscale-train
+  envs:
+    NVTE_DEBUG: 1
+    NVTE_DEBUG_LEVEL: 2
+    # CUDNN_LOGERR_DBG: 1
+    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
+    CUDA_DEVICE_MAX_CONNECTIONS: 1
+    # NCCL_IB_DISABLE: 0
+    NCCL_IB_CUDA_SUPPORT: 1
+    NCCL_IB_GID_INDEX: 0
+    #TORCH_NCCL_USE_COMM_NONBLOCKING: 1
+    OMP_NUM_THREADS: 4
+    ENABLE_FLASH_ATTENTION_WITH_IXDNN: 0
+    NCCL_NET_PLUGIN: none
+    #NCCL_SHM_DISABLE=1
+    NCCL_ALGO: Ring
+    NCCL_P2P_NET_CHUNKSIZE: 1048576
+    NCCL_CHUNK_SIZE: 1048576
+    NCCL_BUFFSIZE: 8388608
+    NCCL_MAX_NCHANNELS: 4
+    NCCL_MIN_NCHANNELS: 4
+    NCCL_MAX_P2P_NCHANNELS: 1
+    NCCL_PROTO: Simple
+    NCCL_NET_SHARED_BUFFERS: 0
+    NCCL_P2P_LL_THRESHOLD: 0
+    IXCCL_MIX_NV: 1
+    IXCCL_FUSED_ENABLE: 0
+    NCCL_IB_DISABLE: 0
+    NCCL_IB_HCA: mlx5_2,mlx5_5
+
+action: run
+
+hydra:
+  run:
+    dir: ${experiment.exp_dir}/hydra
diff --git a/examples/smollm2/conf/train/135m.yaml b/examples/smollm2/conf/train/135m.yaml
new file mode 100644
index 000000000..0dca7c530
--- /dev/null
+++ b/examples/smollm2/conf/train/135m.yaml
@@ -0,0 +1,92 @@
+system:
+  reset_position_ids: True
+  reset_attention_mask: True
+  logging_level: 10
+  distributed_timeout_minutes: 60
+  no_shared_fs: ${experiment.runner.no_shared_fs}
+  num_workers: 4
+  tensor_model_parallel_size: 1
+  pipeline_model_parallel_size: 1
+  context_parallel_size: 1
+  disable_bias_linear: true
+  # qk_layernorm: true
+  sequence_parallel: true
+  use_distributed_optimizer: true
+  overlap_grad_reduce: true
+  overlap_param_gather: true
+  finetune: false
+  precision:
+    bf16: true
+    attention_softmax_in_fp32: true
+    accumulate_allreduce_grads_in_fp32: true
+  logging:
+    log_interval: 1
+    tensorboard_log_interval: 1
+    wandb_project: ${experiment.exp_name}
+    wandb_exp_name: ${experiment.exp_name}
+    log_timers_to_tensorboard: true
+    log_validation_ppl_to_tensorboard: true
+    log_throughput: true
+    log_params_norm: true
+    log_num_zeros_in_grad: true
+    log_memory_to_tensorboard: true
+  checkpoint:
+    save_interval: ${experiment.save_steps}
+    load: ${experiment.load}
+    ckpt_format: ${experiment.ckpt_format}
+
+model:
+  transformer_impl: transformer_engine
+  num_layers: 30
+  hidden_size: 576
+  ffn_hidden_size: 1536
+  kv_channels: 64
+  group_query_attention: True
+  num_attention_heads: 9
+  num_query_groups: 3 # num_key_value_heads
+  seq_length: 4096
+  max_position_embeddings: 4096
+  norm_epsilon: 1e-6
+  use_rotary_position_embeddings: true
+  rotary_base: 100000
+  swiglu: true
+  normalization: RMSNorm
+  init_method_std: 6e-3
+  attention_dropout: 0.0
+  hidden_dropout: 0.0
+  clip_grad: 1.0
+  position_embedding_type: rope
+  untie_embeddings_and_output_weights: false
+  no_position_embedding: true
+  no_rope_fusion: true
+
+  # training
+  seed: ${experiment.seed}
+  micro_batch_size: 1
+  global_batch_size: 1
+  eval_iters: 0
+  train_samples: 16777216 #69B tokens
+
+  optimizer:
+    weight_decay: 0.1
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    lr_scheduler:
+      lr: 3.0e-4
+      min_lr: 3.0e-5
+      lr_warmup_samples: 1024000
+      lr_decay_style: cosine
+
+
+data:
+  data_path:
+    - 1.0
+    - MegaScience_text_document
+
+  split: 1
+  no_mmap_bin_files: true
+  tokenizer:
+    tokenizer_type: SmolLM2TokenizerFS
+    tokenizer_path: /share/projset/ldwang/models/HuggingFaceTB/SmolLM2-135M
+    vocab_size: 49152
+    make_vocab_size_divisible_by: 64
diff --git a/flagscale/backends/Megatron-LM/megatron/training/tokenizer/tokenizer.py.patch b/flagscale/backends/Megatron-LM/megatron/training/tokenizer/tokenizer.py.patch
index 260e82978..7bdb12f55 100644
--- a/flagscale/backends/Megatron-LM/megatron/training/tokenizer/tokenizer.py.patch
+++ b/flagscale/backends/Megatron-LM/megatron/training/tokenizer/tokenizer.py.patch
@@ -1,5 +1,5 @@
 diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
-index 5cf222cc..35ad2b85 100644
+index 5cf222ccc..eaa16fc8b 100644
 --- a/megatron/training/tokenizer/tokenizer.py
 +++ b/megatron/training/tokenizer/tokenizer.py
 @@ -11,9 +11,10 @@ from pathlib import Path
@@ -24,7 +24,7 @@ index 5cf222cc..35ad2b85 100644
      elif args.tokenizer_type == 'NullTokenizer':
          assert args.vocab_size is not None
          tokenizer = _NullTokenizer(args.vocab_size)
-@@ -102,6 +106,31 @@ def build_tokenizer(args, **kwargs):
+@@ -102,6 +106,34 @@ def build_tokenizer(args, **kwargs):
      elif args.tokenizer_type == 'NullMultimodalTokenizer':
          assert args.vocab_size is not None
          tokenizer = _NullMultimodalTokenizer(args.vocab_size)
 +    elif args.tokenizer_type == "AquilaTokenizerFS":
 +        assert args.tokenizer_path is not None
 +        tokenizer = _AquilaTokenizerFS(
 +            args.vocab_file, args.merge_file, args.special_tokens_file
 +        )
 +    elif args.tokenizer_type == "HFTokenizerFS":
 +        assert args.tokenizer_path is not None
 +        tokenizer = _HFTokenizerFS(args.tokenizer_path)
 +    elif args.tokenizer_type == "Llama3TokenizerFS":
 +        assert args.tokenizer_path is not None
 +        tokenizer = _Llama3TokenizerFS(args.tokenizer_path)
 +    elif args.tokenizer_type == "QwenTokenizerFS":
 +        assert args.tokenizer_path is not None
 +        tokenizer = _QwenTokenizerFS(args.tokenizer_path)
++    elif args.tokenizer_type == "SmolLM2TokenizerFS":
++        assert args.tokenizer_path is not None
++        tokenizer = _SmolLM2TokenizerFS(args.tokenizer_path)
 +    elif args.tokenizer_type == "HFTokenizersTokenizerFS":
 +        assert args.tokenizer_path is not None
 +        tokenizer = _HFTokenizersTokenizerFS(args.tokenizer_path)
@@ -56,7 +59,7 @@ index 5cf222cc..35ad2b85 100644
      else:
          raise NotImplementedError('{} tokenizer is not '
                                    'implemented.'.format(args.tokenizer_type))
-@@ -596,6 +625,16 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
+@@ -596,6 +628,16 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
          t = t + [self.eos_id]
          return t
 
@@ -73,7 +76,7 @@ index 5cf222cc..35ad2b85 100644
      def detokenize(self, ids):
          return self.tokenizer.decode_ids(ids)
 
-@@ -909,3 +948,276 @@ class _NullMultimodalTokenizer(MegatronTokenizer):
+@@ -909,3 +951,286 @@ class _NullMultimodalTokenizer(MegatronTokenizer):
      @property
      def additional_special_tokens_ids(self):
          return None
@@ -196,6 +199,16 @@ index 5cf222cc..35ad2b85 100644
 +        self.pad_id = self.tokenizer.encode('<|endoftext|>')[0]
 +
 +
++class _SmolLM2TokenizerFS(_HFTokenizerFS):
++    """Adapted SmolLM2 tokenizer."""
++
++    def __init__(self, tokenizer_path):
++        super().__init__(tokenizer_path)
++        self.eod_id = self.tokenizer.encode('<|endoftext|>')[0]
++        self.cls_id = self.tokenizer.encode('<|endoftext|>')[0]
++        self.pad_id = self.tokenizer.encode('<|endoftext|>')[0]
++
++
 +class _HFTokenizersTokenizerFS(MegatronTokenizer):
 +    """Tokenizer from HuggingFace Tokenizers."""
 +
diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py
index 3a7bffffa..97a60953e 100644
--- a/tools/checkpoint/convert.py
+++ b/tools/checkpoint/convert.py
@@ -39,7 +39,7 @@ def main():
         default=[],
         nargs="+",
         required=True,
-        choices=["aquila", "mistral", "mixtral", "llama", "deepseek_v3", "qwen3"],
+        choices=["aquila", "mistral", "mixtral", "llama", "deepseek_v3", "qwen3", "smollm2"],
"deepseek_v3", "qwen3", "smollm2"], help="Type of the model.", ) parser.add_argument( diff --git a/tools/checkpoint/smollm2/args.py b/tools/checkpoint/smollm2/args.py new file mode 100644 index 000000000..7c2056f5e --- /dev/null +++ b/tools/checkpoint/smollm2/args.py @@ -0,0 +1,94 @@ +import json +import os + + +def load_args_hf2mg(args): + + # Read llama args. + llama_args_path = os.path.join(args.load, "config.json") + with open(llama_args_path) as f: + llama_args = json.load(f) + + # Update Megatron args. + args.attention_dropout = llama_args["attention_dropout"] + args.hidden_dropout = llama_args["attention_dropout"] + args.hidden_size = llama_args["hidden_size"] + args.swiglu = llama_args["hidden_act"] == "silu" + args.init_method_std = llama_args["initializer_range"] + args.ffn_hidden_size = llama_args["intermediate_size"] + args.max_position_embeddings = llama_args["max_position_embeddings"] + args.model_type = llama_args["model_type"] + args.num_attention_heads = llama_args["num_attention_heads"] + args.num_layers = llama_args["num_hidden_layers"] + args.num_query_groups = llama_args["num_key_value_heads"] + args.norm_epsilon = llama_args["rms_norm_eps"] + args.rotary_seq_len_interpolation_factor = ( + None if llama_args["rope_scaling"] == "null" else llama_args["rope_scaling"] + ) + args.rotary_base = llama_args["rope_theta"] + args.untie_embeddings_and_output_weights = not llama_args["tie_word_embeddings"] + args.bf16 = llama_args["dtype"] == "bfloat16" + args.fp16 = llama_args["dtype"] == "float16" + args.vocab_size = llama_args["vocab_size"] + args.padded_vocab_size = llama_args["vocab_size"] + + args.seq_length = 2048 + args.global_batch_size = 1024 + args.iteration = 1 # '0', 'release' don't work + args.add_position_embedding = False + args.group_query_attention = True + args.normalization = "RMSNorm" + args.use_rotary_position_embeddings = True + args.add_bias_linear = False + args.add_qkv_bias = False + args.make_vocab_size_divisible_by = 64 + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + args.norm_has_bias = False + args.tokenizer_type = "Llama3TokenizerFS" + + +def save_args_mg2hf(args): + from transformers import LlamaConfig + + config = LlamaConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + intermediate_size=args.ffn_hidden_size, + num_hidden_layers=args.encoder_num_layers, + num_attention_heads=args.num_attention_heads, + num_key_value_heads=args.num_query_groups, + hidden_act="silu" if args.swiglu else False, + max_position_embeddings=args.max_position_embeddings, + initializer_range=args.init_method_std, + rms_norm_eps=args.norm_epsilon, + use_cache=True, + tie_word_embeddings=not args.untie_embeddings_and_output_weights, + rope_theta=args.rotary_base, + rope_scaling=args.rotary_seq_len_interpolation_factor, + attention_bias=args.add_qkv_bias, + attention_dropout=args.attention_dropout, + torch_dtype=args.params_dtype, + bias_dropout_fusion=args.bias_dropout_fusion, + end_weight_decay=args.end_weight_decay, + global_batch_size=args.global_batch_size, + hidden_dropout=args.hidden_dropout, + lr=args.lr, + lr_decay_style=args.lr_decay_style, + make_vocab_size_divisible_by=args.make_vocab_size_divisible_by, + masked_softmax_fusion=args.masked_softmax_fusion, + min_lr=args.min_lr, + norm_init_weight=args.norm_init_weight, + perform_initialization=args.perform_initialization, + reset_attention_mask=args.reset_attention_mask, + reset_position_ids=args.reset_position_ids, + rotary_base=args.rotary_base, + seed=args.seed, + 
+        split=args.split,
+        start_weight_decay=args.start_weight_decay,
+        use_flash_attn=args.use_flash_attn,
+        weight_decay_incr_style=args.weight_decay_incr_style,
+    )
+    config.save_pretrained(args.save)
+
+    return config
diff --git a/tools/checkpoint/smollm2/ckpt.py b/tools/checkpoint/smollm2/ckpt.py
new file mode 100644
index 000000000..efabe0856
--- /dev/null
+++ b/tools/checkpoint/smollm2/ckpt.py
@@ -0,0 +1,209 @@
+import sys
+
+import torch
+
+sys.path.append("..")
+from mixtral.ckpt import (
+    get_embedding_ckpt,
+    get_final_norm_ckpt,
+    get_hf_attn_ckpt,
+    get_output_layer_ckpt,
+    set_embedding_ckpt,
+    set_final_norm_ckpt,
+    set_hf_attn_ckpt,
+    set_hf_embedding_ckpt,
+    set_hf_final_norm_ckpt,
+    set_hf_output_layer_ckpt,
+    set_output_layer_ckpt,
+)
+
+
+def get_hf_mlp_ckpt(message, model, layer_id, args):
+    assert args.swiglu is True
+
+    tf_layer = model.model.layers[layer_id]
+    message["mlp l0 weight W"] = tf_layer.mlp.gate_proj.weight.data
+    message["mlp l0 weight V"] = tf_layer.mlp.up_proj.weight.data
+    message["mlp l1 weight"] = tf_layer.mlp.down_proj.weight.data
+
+    if args.add_bias_linear:
+        message["mlp l0 bias W"] = tf_layer.mlp.gate_proj.bias.data
+        message["mlp l0 bias V"] = tf_layer.mlp.up_proj.bias.data
+        message["mlp l1 bias"] = tf_layer.mlp.down_proj.bias.data
+
+
+def set_hf_mlp_ckpt(message, model, layer_id, md, args):
+    assert args.swiglu is True
+
+    tf_layer = model.model.layers[layer_id]
+    tf_layer.mlp.gate_proj.weight.data.copy_(message.pop("mlp l0 weight W"))
+    tf_layer.mlp.up_proj.weight.data.copy_(message.pop("mlp l0 weight V"))
+    tf_layer.mlp.down_proj.weight.data.copy_(message.pop("mlp l1 weight"))
+
+    if md.add_bias_linear:
+        tf_layer.mlp.gate_proj.bias.data.copy_(message.pop("mlp l0 bias W"))
+        tf_layer.mlp.up_proj.bias.data.copy_(message.pop("mlp l0 bias V"))
+        tf_layer.mlp.down_proj.bias.data.copy_(message.pop("mlp l1 bias"))
+
+
+def _get_parallel_size(args):
+    assert args.expert_model_parallel_size == 1
+    return (
+        args.tensor_model_parallel_size,
+        args.pipeline_model_parallel_size,
+        args.expert_model_parallel_size,
+        args.virtual_pipeline_model_parallel_size or 1,
+    )
+
+
+def get_attn_ckpt(message, models, layer_id, args):
+    tp_size, _, _, _ = _get_parallel_size(args)
+
+    # parallel tensor
+    qkv_weight = []
+    qkv_bias = []
+    proj_weight = []
+    # non-parallel tensor
+    proj_bias = None
+    input_norm_weight = None
+    input_norm_bias = None
+    post_norm_weight = None
+    post_norm_bias = None
+
+    assert len(models) == tp_size
+    for model in models:
+        tf_layer = model.decoder.layers[layer_id]
+        # weight
+        qkv_weight.append(tf_layer.self_attention.linear_qkv.weight.data)
+        proj_weight.append(tf_layer.self_attention.linear_proj.weight.data)
+        input_norm_weight = tf_layer.self_attention.linear_qkv.layer_norm_weight.data
+        post_norm_weight = tf_layer.mlp.linear_fc1.layer_norm_weight.data
+        # bias
+        if args.norm_has_bias:
+            input_norm_bias = tf_layer.self_attention.linear_qkv.layer_norm_bias.data
+            post_norm_bias = tf_layer.mlp.linear_fc1.layer_norm_bias.data
+        if args.add_qkv_bias or args.add_bias_linear:
+            qkv_bias.append(tf_layer.self_attention.linear_qkv.bias.data)
+        if args.add_bias_linear:
+            proj_bias = tf_layer.self_attention.linear_proj.bias.data
+
+    # weight
+    message["qkv weight"] = torch.cat(qkv_weight, dim=0)
+    message["proj weight"] = torch.cat(proj_weight, dim=1)
+    message["input norm weight"] = input_norm_weight
+    message["post norm weight"] = post_norm_weight
+    # bias
+    if args.norm_has_bias:
+        message["input norm bias"] = input_norm_bias
message["post norm bias"] = post_norm_bias + if args.add_qkv_bias or args.add_bias_linear: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + if args.add_bias_linear: + message["proj bias"] = proj_bias + + +def get_mlp_ckpt(message, models, layer_id, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # parallel tensor + l0_weight = [] + l0_bias = [] + l1_weight = [] + # non-parallel tensor + l1_bias = None + + assert len(models) == tp_size + for model in models: + tf_layer = model.decoder.layers[layer_id] + # weight + l0_weight.append(tf_layer.mlp.linear_fc1.weight.data) + l1_weight.append(tf_layer.mlp.linear_fc2.weight.data) + # bias + if args.add_bias_linear: + l0_bias.append(tf_layer.mlp.linear_fc1.bias.data) + l1_bias = tf_layer.mlp.linear_fc2.bias.data + + # weight + message["mlp l1 weight"] = torch.cat(l1_weight, dim=1) + if args.swiglu: + for tp_rank in range(tp_size): + l0_weight[tp_rank] = torch.chunk(l0_weight[tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in l0_weight], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.cat(l0_weight, dim=0) + # bias + if args.add_bias_linear: + message["mlp l1 bias"] = l1_bias + if args.swiglu: + for tp_rank in range(tp_size): + l0_bias[tp_rank] = torch.chunk(l0_bias[tp_rank], 2, dim=0) + message["mlp l0 bias W"] = torch.cat([b[0] for b in l0_bias], dim=0) + message["mlp l0 bias V"] = torch.cat([b[1] for b in l0_bias], dim=0) + else: + message["mlp l0 bias"] = torch.cat(l0_bias, dim=0) + + +def set_attn_ckpt(message, models, layer_id, md, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # weight + qkv_weight = torch.chunk(message.pop("qkv weight"), tp_size, dim=0) + proj_weight = torch.chunk(message.pop("proj weight"), tp_size, dim=1) + input_norm_weight = message.pop("input norm weight") + post_norm_weight = message.pop("post norm weight") + # bias + if md.norm_has_bias: + input_norm_bias = message.pop("input norm bias") + post_norm_bias = message.pop("post norm bias") + if md.add_qkv_bias or md.add_bias_linear: + qkv_bias = torch.chunk(message.pop("qkv bias"), tp_size, dim=0) + if md.add_bias_linear: + proj_bias = message.pop("proj bias") + + # set data to transformer layer's self-attention + for tp_rank, model in enumerate(models): + tf_layer = model.decoder.layers[layer_id] + tf_layer.self_attention.linear_qkv.weight.data.copy_(qkv_weight[tp_rank]) + tf_layer.self_attention.linear_proj.weight.data.copy_(proj_weight[tp_rank]) + tf_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(input_norm_weight) + tf_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_norm_weight) + if md.norm_has_bias: + tf_layer.self_attention.linear_qkv.layer_norm_bias.data.copy_(input_norm_bias) + tf_layer.mlp.linear_fc1.layer_norm_bias.data.copy(post_norm_bias) + if md.add_qkv_bias or md.add_bias_linear: + tf_layer.self_attention.linear_qkv.bias.data.copy_(qkv_bias[tp_rank]) + if md.add_bias_linear: + tf_layer.self_attention.linear_proj.bias.data.copy_(proj_bias) + + +def set_mlp_ckpt(message, models, layer_id, md, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # weight + l1_weight = torch.chunk(message.pop("mlp l1 weight"), tp_size, dim=1) + if md.swiglu: + l0_weight_W = torch.chunk(message.pop("mlp l0 weight W"), tp_size, dim=0) + l0_weight_V = torch.chunk(message.pop("mlp l0 weight V"), tp_size, dim=0) + l0_weight = [torch.cat(weights, dim=0) for weights in zip(l0_weight_W, l0_weight_V)] + else: + l0_weight = 
torch.chunk(message.pop("mlp l0 weight"), tp_size, dim=0) + # bias + if md.add_bias_linear: + l1_bias = message.pop("mlp l1 bias") + if md.swiglu: + l0_bias_W = torch.chunk(message.pop("mlp l0 bias W"), tp_size, dim=0) + l0_bias_V = torch.chunk(message.pop("mlp l0 bias V"), tp_size, dim=0) + l0_bias = [torch.cat(bias, dim=0) for bias in zip(l0_bias_W, l0_bias_V)] + else: + l0_bias = torch.chunk(message.pop("mlp l0 bias"), tp_size, dim=0) + + # set data to transformer layer for mlp + for tp_rank, model in enumerate(models): + tf_layer = model.decoder.layers[layer_id] + tf_layer.mlp.linear_fc1.weight.data.copy_(l0_weight[tp_rank]) + tf_layer.mlp.linear_fc2.weight.data.copy_(l1_weight[tp_rank]) + + if md.add_bias_linear: + tf_layer.mlp.linear_fc1.bias.data.copy_(l0_bias[tp_rank]) + tf_layer.mlp.linear_fc2.bias.data.copy_(l1_bias) diff --git a/tools/checkpoint/smollm2/model.py b/tools/checkpoint/smollm2/model.py new file mode 100644 index 000000000..0a9e02bf3 --- /dev/null +++ b/tools/checkpoint/smollm2/model.py @@ -0,0 +1,4 @@ +import sys + +sys.path.append("..") +from mixtral.model import *
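Reviewer note, not part of the patch: the new _SmolLM2TokenizerFS class added to tokenizer.py.patch mirrors the existing _QwenTokenizerFS adapter. It wraps the Hugging Face tokenizer loaded from --tokenizer-path and reuses the <|endoftext|> token id for the eod, cls, and pad ids that Megatron expects. The short Python sketch below illustrates that behaviour outside Megatron; it is a minimal sketch under stated assumptions (the transformers AutoTokenizer API and the public HuggingFaceTB/SmolLM2-135M checkpoint), and the helper name smollm2_special_ids is hypothetical, not part of this change.

# Minimal sketch (assumes `transformers` is installed); shows what the
# _SmolLM2TokenizerFS adapter resolves in its __init__.
from transformers import AutoTokenizer


def smollm2_special_ids(tokenizer_path: str) -> dict:
    # Load the tokenizer from a local checkpoint directory or a hub id.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # The patch reuses the <|endoftext|> id for eod, cls, and pad,
    # exactly like the Qwen adapter it is modeled on.
    eot_id = tokenizer.encode("<|endoftext|>")[0]
    return {"eod": eot_id, "cls": eot_id, "pad": eot_id}


if __name__ == "__main__":
    # Either the tokenizer_path from examples/smollm2/conf/train/135m.yaml
    # or the public hub checkpoint can be passed here.
    print(smollm2_special_ids("HuggingFaceTB/SmolLM2-135M"))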