diff --git a/examples/llama/load_weights.py b/examples/llama/load_weights.py
new file mode 100644
index 000000000..0c5bd3fee
--- /dev/null
+++ b/examples/llama/load_weights.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import json
+import torch
+
+from typing import Callable, Dict, Optional
+
+
+def load_weights(
+    stage_module: torch.nn.Module,
+    weight_index_file: str = "pytorch_model.bin.index.json",
+):
+    """
+    Load weights from Hugging Face checkpoints into a stage module.
+
+    This is a utility for Hugging Face Model Hub checkpoints that come with an
+    index file and multiple binary files. The index file indicates which
+    parameter is saved in which binary file. An example can be found at:
+    https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+    Please download the following files into the same directory as this script:
+    - pytorch_model.bin.index.json
+    - pytorch_model-00001-of-00002.bin
+    - pytorch_model-00002-of-00002.bin
+    """
+
+    state_dict = stage_module.state_dict()
+    updated_states = dict()
+
+    # Get the weight map -- a map from parameter name to the file it is saved in
+    with open(weight_index_file) as f:
+        js = json.load(f)
+    weight_map = js["weight_map"]
+
+    # Figure out the set of binary files we need to open in order to fill the
+    # state dict of the stage module. It is a subset of all the binary files
+    # because the stage module is a partition of the full model.
+    needed_files = set()
+    for param in state_dict.keys():
+        # The file this param is saved in (None if missing from the index)
+        file = weight_map.get(param)
+        if file:
+            needed_files.add(file)
+
+    # Now we load the needed binary files
+    for file in needed_files:
+        checkpoint = torch.load(file, weights_only=True)
+        for param in state_dict.keys():
+            file_having_param = weight_map.get(param)
+            if file_having_param is None:
+                print(f"Cannot find checkpoint file for {param}, skipping")
+            elif file_having_param == file:
+                state_dict[param] = checkpoint[param]
+                updated_states[param] = None
+
+    # Check whether the module's state dict will be fully updated from the checkpoint
+    if state_dict.keys() == updated_states.keys():
+        print("Fully updated state dict")
+    else:
+        print("Partially updated state dict")
+
+    # Now load the weights into the stage module.
+    # We use `assign=True` so the checkpoint tensors replace the existing
+    # (meta) tensors instead of being copied into them.
+    stage_module.load_state_dict(state_dict, assign=True)
+
+
+def init_buffers(
+    stage_module: torch.nn.Module,
+    init_callbacks: Dict[str, Callable],
+    device: torch.device,
+    dtype: Optional[torch.dtype] = None,
+):
+    """
+    Initialize buffers of `stage_module` per the callbacks in `init_callbacks`.
+    `init_callbacks` is a dictionary from a buffer's FQN to its init function.
+    """
+    for name, buf in stage_module.named_buffers():
+        if name in init_callbacks:
+            cb = init_callbacks[name]
+            buf_val = cb(device)
+            if dtype:
+                buf_val = buf_val.to(dtype)
+            # Find the parent module
+            splits = name.split(".")
+            mod = stage_module
+            for atom in splits[:-1]:
+                mod = getattr(mod, atom)
+            mod.register_buffer(
+                splits[-1], buf_val, persistent=False,
+            )
+            print(f"Initialized buffer {name}, {buf_val.dtype}, {buf_val.device}")
+
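The `init_buffers` helper above expects a dictionary from a buffer's FQN to a function that rebuilds it on a given device, but the PR never shows what such a dictionary looks like (meta_init.py below only checks for a `buf_init_callbacks` attribute). A minimal, hypothetical sketch of one follows; the buffer name and the rotary-frequency formula are illustrative assumptions, not values taken from the model:

    # Hypothetical example of an init_callbacks dict for init_buffers.
    # The buffer FQN and the inv_freq formula are assumptions for illustration.
    import torch

    def make_inv_freq(device: torch.device) -> torch.Tensor:
        # Rotary-embedding inverse frequencies (assuming head_dim=128, base=10000)
        dim, base = 128, 10000.0
        return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))

    init_callbacks = {
        "model.layers.0.self_attn.rotary_emb.inv_freq": make_inv_freq,
    }
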
diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
new file mode 100644
index 000000000..605d8a958
--- /dev/null
+++ b/examples/llama/meta_init.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+"""
+This script shows how to create the llama model on the "meta" device, partition
+it into pipeline stages, and materialize each stage module from Hugging Face
+checkpoints.
+
+Before running the script, please download the following files into the same
+directory as this script:
+- pytorch_model.bin.index.json
+- pytorch_model-00001-of-00002.bin
+- pytorch_model-00002-of-00002.bin
+
+Download link:
+https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+How to run this script:
+$ python meta_init.py
+
+This script does not use a distributed runtime because it only aims to showcase
+how to load each stage module. Feel free to modify it to run in a distributed
+way by distributing the for loop at [Note 3].
+"""
+
+import os
+import torch
+from torch.distributed.pipelining import pipeline, SplitPoint
+from torch._subclasses.fake_tensor import FakeTensorMode
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from load_weights import load_weights, init_buffers
+
+# Grab the model in meta/fake mode
+fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
+
+with torch.device("meta"):
+    llama = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-chat-hf"
+    )
+
+llama.eval()
+print(llama)
+
+# Cast the model to FakeTensor with a real device (from the meta device),
+# because llama contains autocast code and autocast behaves according to the
+# device of the tensor, so it needs a real device instead of the meta device.
+with fake_mode:
+    # [Note 1]: set device to "cuda" if you are using GPUs
+    llama.to_empty(device="cpu")
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+tokenizer.pad_token = tokenizer.eos_token
+prompts = (
+    "How do you", "I like to",
+)
+
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+real_ids = inputs["input_ids"]
+# The example input needs to be a FakeTensor too
+fake_ids = fake_mode.from_tensor(real_ids)
+
+# Beginning of the distributed part
+# [Note 2]: change world size here
+world_size = 2
+print(f"{world_size=}")
+
+# Cut the model into an equal number of layers per rank
+layers_per_rank = llama.config.num_hidden_layers // world_size
+print(f"layers_per_rank = {layers_per_rank}")
+split_spec = {
+    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+    for i in range(1, world_size)
+}
+
+# Convert the model into a pipeline
+pipe = pipeline(
+    llama,
+    mb_args=(fake_ids,),
+    mb_kwargs={"output_attentions": False, "output_hidden_states": False, "use_cache": False},
+    split_spec=split_spec,
+)
+
+# Materialize each stage
+# [Note 3]: remove this for loop if you are running this script in a
+# distributed manner
+for rank in range(world_size):
+    stage_module = pipe.get_stage_module(rank)
+    print(f"Loading weights into stage {rank}")
+    load_weights(stage_module)
+    if hasattr(llama, "buf_init_callbacks"):
+        init_buffers(stage_module, llama.buf_init_callbacks, "cpu", torch.float16)
+    stage_module.print_readable()
+
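As the docstring of meta_init.py suggests, the loop at [Note 3] can be distributed rather than run serially. A minimal sketch of that change, assuming the script is launched with torchrun and reusing `pipe`, `llama`, `load_weights`, and `init_buffers` from the script above, so each rank materializes only its own stage:

    # Hypothetical sketch, not part of this PR. Launch with e.g.:
    #   torchrun --nproc-per-node 2 meta_init.py
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Each rank loads weights only for the stage it owns
    stage_module = pipe.get_stage_module(rank)
    print(f"[rank {rank}] loading weights into its stage")
    load_weights(stage_module)
    if hasattr(llama, "buf_init_callbacks"):
        init_buffers(stage_module, llama.buf_init_callbacks, "cpu", torch.float16)

    dist.destroy_process_group()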