diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py
index 03b7ec84c1..71cdf9efb7 100644
--- a/tools/checkpoint_saver_megatron.py
+++ b/tools/checkpoint_saver_megatron.py
@@ -162,13 +162,15 @@ def check_message(msg):
             setattr(margs, arg, value)
 
     validate_args(margs)
-
+    margs.ckpt_transfer = True
+    if args.tokenizer_model:
+        margs.tokenizer_model = args.tokenizer_model
     set_global_variables(margs)
 
     # margs = megatron args
     margs = get_args()
-    margs.ckpt_transfer = True
+    print("args.tokenizer_model", args.tokenizer_model)
 
     if hasattr(md, 'consumed_train_samples'):
         margs.consumed_train_samples = md.consumed_train_samples
         margs.consumed_valid_samples = md.consumed_valid_samples
diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md
index 3f74bb1aa4..af09947cbb 100644
--- a/tools/convert_checkpoint/README.md
+++ b/tools/convert_checkpoint/README.md
@@ -76,3 +76,24 @@ cd /hf/transformers
 python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
 /path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
 ```
+
+## HF Transformers to Megatron-DeepSpeed (currently only supports Llama)
+
+To convert a Llama model from HF Transformers to Megatron-DeepSpeed, follow these two steps:
+
+```bash
+# 1. Convert the Llama weights from HF to Megatron
+python tools/convert_checkpoint/transformers_to_megatron_llama.py \
+  --out=/path/to/Megatron-Deepspeed/checkpoint/ \
+  --cache-dir=/path/to/hf/transformers/llama_checkpoint
+
+# 2. Split the Megatron-DeepSpeed checkpoint by TP and PP size
+python3 tools/checkpoint_util.py \
+  --target-tensor-parallel-size 4 \
+  --target-pipeline-parallel-size 2 \
+  --load-dir /path/to/Megatron-Deepspeed/checkpoint/ \
+  --save-dir /path/to/Megatron-Deepspeed/distribute_checkpoint/ \
+  --model-type GPT
+```
+
+
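Before running step 2, the single-rank checkpoint written by step 1 can be sanity-checked. This is an illustrative sketch, not part of the patch: the path is a placeholder, and the expected keys are taken from the `final_dict` that `transformers_to_megatron_llama.py` saves further down in this diff.

```python
import torch

# Placeholder: the directory passed as --out in step 1.
ckpt_dir = "/path/to/Megatron-Deepspeed/checkpoint"

# The converter writes a single-rank "release" checkpoint.
state = torch.load(f"{ckpt_dir}/release/mp_rank_00/model_optim_rng.pt", map_location="cpu")

print(state["iteration"], state["checkpoint_version"])    # release 3.0
print(list(state["model"]["language_model"].keys()))      # ['embedding', 'encoder', 'lm_head']
print(state["args"].num_layers, state["args"].hidden_size)
```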
diff --git a/tools/convert_checkpoint/weights2megatron/merge_llama.py b/tools/convert_checkpoint/merge_llama.py
similarity index 83%
rename from tools/convert_checkpoint/weights2megatron/merge_llama.py
rename to tools/convert_checkpoint/merge_llama.py
index 6d7f2cd5ea..6235211a21 100644
--- a/tools/convert_checkpoint/weights2megatron/merge_llama.py
+++ b/tools/convert_checkpoint/merge_llama.py
@@ -86,12 +86,9 @@ def merge_meta_llama(size: int, root_dir: Path):
     return merged_ckpt
 
 
-def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, model_path=None, tokenizer_len=32000):
-    assert version == 2, "Only llama v2 available using huggingface"
-    print(cache_dir)
+def merge_hf_llama(cache_dir: Optional[Path] = None):
+    # assert version == 2, "Only llama v2 available using huggingface"
     model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False)
-    # resize token embeddings size according saved tokenizer for model extend token size.
-    # model.resize_token_embeddings(tokenizer_len)
     weights = model.state_dict()
     weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight")
     weights["norm.weight"] = weights.pop("model.norm.weight")
@@ -110,12 +107,5 @@ def merge_hf_llama(size: int, version: int, cache_dir: Optional[Path] = None, mo
                 "post_attention_layernorm": "ffn_norm"
             }[rmatch.group(2)]
             weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key)
-    return weights
+    return weights, model.config
 
-
-def merge_llama(size: int, version: int, root_dir: Optional[Path] = None, tokenizer_len: Optional[int] = 32000):
-    if root_dir is not None and (root_dir/"consolidated.00.pth").exists():
-        return merge_meta_llama(size, root_dir), "meta"
-    print(f"Weights at {root_dir} do not look like a meta checkpoint, assuming "
-          "huggingface cache_dir instead")
-    return merge_hf_llama(size, version, root_dir, tokenizer_len), "hf"
diff --git a/tools/convert_checkpoint/weights2megatron/permute_qkv.py b/tools/convert_checkpoint/permute_qkv.py
similarity index 100%
rename from tools/convert_checkpoint/weights2megatron/permute_qkv.py
rename to tools/convert_checkpoint/permute_qkv.py
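The new `merge_hf_llama()` signature drops the size/version arguments and returns the HF `LlamaConfig` alongside the renamed state dict, so callers no longer need hard-coded size tables. A minimal usage sketch, assuming `tools/convert_checkpoint` is on the import path and using a placeholder checkpoint directory:

```python
from pathlib import Path

from merge_llama import merge_hf_llama

# Placeholder: a local HF Transformers Llama checkpoint directory.
weights, llama_config = merge_hf_llama(Path("/path/to/hf/transformers/llama_checkpoint"))

# Model geometry now comes from the config instead of the old llama_s2* lookup tables.
print(llama_config.num_hidden_layers, llama_config.num_attention_heads,
      llama_config.num_key_value_heads)

# Keys are renamed to the Megatron-style naming, e.g. "model.embed_tokens.weight"
# becomes "tok_embeddings.weight".
print(weights["tok_embeddings.weight"].shape)
```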
diff --git a/tools/convert_checkpoint/weights2megatron/weights2megatron_llama.py b/tools/convert_checkpoint/transformers_to_megatron_llama.py
similarity index 60%
rename from tools/convert_checkpoint/weights2megatron/weights2megatron_llama.py
rename to tools/convert_checkpoint/transformers_to_megatron_llama.py
index 4b5d56448b..65a94f2406 100644
--- a/tools/convert_checkpoint/weights2megatron/weights2megatron_llama.py
+++ b/tools/convert_checkpoint/transformers_to_megatron_llama.py
@@ -8,24 +8,14 @@
 import torch
 from tqdm.auto import trange
 from transformers import AutoModelForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig
 
 from permute_qkv import permute_qkv
-from merge_llama import merge_llama
-from transformers import AutoTokenizer
+from merge_llama import merge_hf_llama
 
-llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
-llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
-llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
-                 70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
-llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
-
-
-def llama_to_megatron(weights: dict, size: int, source: str = "meta",
-                      version: int = 1) -> dict:
+def llama_to_megatron(weights: dict, llama_config: LlamaConfig = None) -> dict:
     def permute(qkv_w):
-        if source == "hf":
-            return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)
-        return qkv_w
+        return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads)
 
     def rearrange_qkv(wq, wk, wv):
         wq = torch.split(wq, n_hidden_per_head, dim=0)
@@ -42,12 +32,11 @@ def rearrange_qkv(wq, wk, wv):
         return permute(torch.concat(w_qkv))
 
     # config
-    n_layer = llama_s2layer[size]
-    hidden = llama_s2hidden[size]
-    n_heads = llama_s2heads[size]
+    n_layer = llama_config.num_hidden_layers
+    hidden = llama_config.hidden_size
+    n_heads = llama_config.num_attention_heads
     n_hidden_per_head = hidden//n_heads
-    n_kv_heads = n_heads if version == 1 or size <= 13 else 8
-
+    n_kv_heads = llama_config.num_key_value_heads
     # weights independent of layers
     embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}}
     transformer = {"final_layernorm.weight": weights["norm.weight"]}
@@ -86,32 +75,34 @@ def rearrange_qkv(wq, wk, wv):
     return {"embedding": embedding, "encoder": transformer, "lm_head": lm_head}
 
 
-def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
-         cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None, padded_vocab_size: Optional[int] = 32000):
+def main(out: Optional[Path] = None,
+         cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None):
+    if megatron_path:
+        print("Add megatron to sys path")
+        sys.path.append(megatron_path)
 
     # get weights from or specified directory
     print("Getting llama...")
-    version = 2 if "2" in model_name else 1
-    hf_weights, llama_source = merge_llama(size, version, cache_dir, padded_vocab_size)
+    hf_weights, llama_config = merge_hf_llama(cache_dir)
 
     # convert state dict to be megatron-compatible
-    megatron_weights = llama_to_megatron(hf_weights, size, llama_source,
-                                         version=1 if model_name == "llama" else 2)
+    megatron_weights = llama_to_megatron(hf_weights, llama_config=llama_config)
 
     # set args
     # llama1, llama2
-    args = {"num_layers": llama_s2layer[size],
-            "hidden_size": llama_s2hidden[size],
-            "num_attention_heads": llama_s2heads[size],
-            "ffn_hidden_size": llama_s2dense[size],
-            "num_key_value_heads": llama_s2heads[size],
+    args = {"num_layers": llama_config.num_hidden_layers,
+            "hidden_size": llama_config.hidden_size,
+            "num_attention_heads": llama_config.num_attention_heads,
+            "ffn_hidden_size": llama_config.intermediate_size,
+            "num_key_value_heads": llama_config.num_key_value_heads,
             "parallel_attn": False,
             "make_vocab_size_divisible_by": 1,
             "glu_activation": "swiglu",
+            "max_position_embeddings": llama_config.max_length, # should use max_length rather than max_position_embeddings, detail in https://github.com/lm-sys/FastChat/issues/2046#issuecomment-1645265800
+            "seq_length": llama_config.max_length,
+            "layernorm_epsilon": llama_config.rms_norm_eps,
             # llama args
-            "padded_vocab_size": padded_vocab_size,
-            "use_rms_norm": True,
-            "tie_embed_logits": False,
+            "padded_vocab_size": llama_config.vocab_size,
             "tokenizer_type": "GPTSentencePieceTokenizer",
             "no-query-key-layer-scaling": True,
             "attention-dropout": 0,
@@ -124,19 +115,13 @@ def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
             "add_position_embedding": False,
             "add_bias_linear": False,
             }
-    if model_name == "llama":
-        args.update({"max_position_embeddings": 2048, "seq_length": 2048,
-                     "layernorm_epsilon": 1e-6})
-    else: # llama2
-        args.update({"max_position_embeddings": 2048, "seq_length": 2048,
-                     "layernorm_epsilon": 1e-5})
-        if size >= 34:
-            args.update({"num_attention_heads_kv": 8})
+    if llama_config.num_key_value_heads:
+        args.update({"num_attention_heads_kv": llama_config.num_key_value_heads})
 
     args.update({
         "tensor_model_parallel_size": 1,
         "pipeline_model_parallel_size": 1,
-        "iteration": "release",
+        "iteration": 0,
         "bias_gelu_fusion": False,
         "bias_droput_fusion": False,
     })
@@ -145,42 +130,31 @@ def main(model_name: str = "llama2", size: int = 7, out: Optional[Path] = None,
     (out/"release"/"mp_rank_00").mkdir(parents=True)
     with open(out/"latest_checkpointed_iteration.txt", "w+") as f:
         f.write("release")
-    final_dict = {"iteration": "release", "model": {"language_model": megatron_weights},
+    final_dict = {"iteration": 'release', "model": {"language_model": megatron_weights},
                   "checkpoint_version": 3.0, "args": Namespace(**args)}
     torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt")
     print("Saved weights in", out)
-    if model_name == "llama2" and llama_source == "hf":
-        tokenizer = LlamaTokenizer.from_pretrained(
-            cache_dir, cache_dir=cache_dir, local_files_only=True,
-        )
-        token_path = out/"tokenizer.model"
-        vocab_file = tokenizer.vocab_file
-        shutil.copy(vocab_file, token_path)
-        print("Saved tokenizer.model in", token_path)
+    tokenizer = LlamaTokenizer.from_pretrained(
+        cache_dir, cache_dir=cache_dir, local_files_only=True,
+    )
+    token_path = out/"tokenizer.model"
+    vocab_file = tokenizer.vocab_file
+    shutil.copy(vocab_file, token_path)
+    print("Saved tokenizer.model in", token_path)
     print("Done")
 
 
 if __name__ == "__main__":
-    parser = ArgumentParser(description="Convert Huggingface falcon weights to "
+    parser = ArgumentParser(description="Convert Huggingface llama weights to "
                             "megatron-compatible weights")
-    parser.add_argument("model", choices={"falcon", "llama", "llama2"})
-    parser.add_argument("--size", default=7, choices={7, 13, 30, 34, 40, 65, 70}, type=int,
-                        help="The size of the model")
     parser.add_argument("--out", type=Path,
                         help="Directory to store the megatron weights (as checkpoint)")
     parser.add_argument("--cache-dir", type=Path,
                         help=("Directory to store the huggingface weights, or "
                               "in case of the llama model, where to look for "
                               "the consolidated.xx.pth"))
-    parser.add_argument("--megatron-path", type=Path,
+    parser.add_argument("--megatron-path", type=Path, default=None,
                         help="Path where to find megatron code")
-    parser.add_argument("--tokenizer-size", type=int, help="Directory to store the megatron weights (as checkpoint)", default=None)
     args = parser.parse_args()
 
-    # small arg verification
-    if args.model == "llama":
-        assert args.size in {7, 13, 30, 65}
-    else:
-        assert args.size in {7, 13, 70}
-
-    main(args.model, args.size, args.out, args.cache_dir, args.megatron_path, args.tokenizer_size)
+    main(args.out, args.cache_dir, args.megatron_path)
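As a worked reference for the head bookkeeping that `rearrange_qkv` and `permute_qkv` rely on, here is the arithmetic with the 70B values from the `llama_s2*` tables deleted above (the script now reads the same values from `LlamaConfig`). The fused-QKV row count on the last line reflects the usual grouped-query layout and is an assumption, not something shown in these hunks:

```python
# Values for Llama-2 70B, taken from the deleted llama_s2hidden/llama_s2heads tables
# and the old "else 8" rule for n_kv_heads.
hidden, n_heads, n_kv_heads = 8192, 64, 8

n_hidden_per_head = hidden // n_heads   # 128 dims per head, as computed in llama_to_megatron
queries_per_kv = n_heads // n_kv_heads  # 8 query heads share each kv head (grouped-query attention)

# Assumed fused-QKV layout: each kv group contributes its query heads plus one k and one v head.
fused_qkv_rows = (n_heads + 2 * n_kv_heads) * n_hidden_per_head
print(n_hidden_per_head, queries_per_kv, fused_qkv_rows)  # 128 8 10240
```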
diff --git a/tools/convert_checkpoint/weights2megatron/README.md b/tools/convert_checkpoint/weights2megatron/README.md
deleted file mode 100644
index d86bb759b9..0000000000
--- a/tools/convert_checkpoint/weights2megatron/README.md
+++ /dev/null
@@ -1,19 +0,0 @@
-# Introduction
-This folder is a collection of scripts for converting hf checkpoints to megatron-DeepSpeed checkpoints.
-
-# Usage
-## huggingface to megatron
-```bash
-python weights2megatron/weights2megatron.py llama2 --size=13 --out=${DEST_DIR} --cache-dir=${HF_CKPT_DIR} --tokenizer-size=32000
-```
-
-## split ckpt by TP and PP size
-```bash
- python3 tools/checkpoint_util.py \
- --target-tensor-parallel-size 4 \
- --target-pipeline-parallel-size 2 \
- --load-dir ${LOAD_DIR} \
- --save-dir ${SAVE_DIR} \
- --model-type GPT \
- --true-vocab-size 32000
-```
\ No newline at end of file