From 1525384d7e7cd3642b37e6985d223fd84e171fef Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 23 Jul 2024 08:18:35 -0700
Subject: [PATCH 1/4] Meta init llama then pipeline then materialize

---
 examples/llama/load_weights.py | 61 ++++++++++++++++++++++++
 examples/llama/meta_init.py    | 84 ++++++++++++++++++++++++++++++++++
 2 files changed, 145 insertions(+)
 create mode 100644 examples/llama/load_weights.py
 create mode 100644 examples/llama/meta_init.py

diff --git a/examples/llama/load_weights.py b/examples/llama/load_weights.py
new file mode 100644
index 000000000..f79ef2740
--- /dev/null
+++ b/examples/llama/load_weights.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+import json
+import torch
+
+from typing import Optional
+
+
+def load_weights(
+    stage_module: torch.nn.Module,
+    weight_index_file: Optional[str] = "pytorch_model.bin.index.json",
+):
+    """
+    Load weights from Hugging Face checkpoints into a stage module.
+
+    This is a utility for Hugging Face ModelHub checkpoints that come with an
+    index file and multiple binary files. The index file indicates which
+    parameter is saved in which binary. An example can be found at:
+    https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+    Please download the following files into the same directory as this script:
+    - pytorch_model.bin.index.json
+    - pytorch_model-00001-of-00002.bin
+    - pytorch_model-00002-of-00002.bin
+    """
+
+    state_dict = stage_module.state_dict()
+    updated_states = dict()
+
+    # Get the weight map -- a map from parameter name to file it is saved in
+    f = open(weight_index_file)
+    js = json.load(f)
+    weight_map = js["weight_map"]
+
+    # Figure out the set of binary files we'd need to open in order to fill the
+    # state dict of the stage module. It will be a subset of all the binary
+    # files because the stage module is a partition of the full model.
+    needed_files = set()
+    for param in state_dict.keys():
+        file = weight_map[param]
+        needed_files.add(file)
+
+    # Now we load the needed binary files
+    for file in needed_files:
+        checkpoint = torch.load(file, weights_only=True)
+        for param in state_dict.keys():
+            if weight_map[param] == file:
+                state_dict[param] = checkpoint[param]
+                updated_states.setdefault(param, None)
+
+    # Check if the module's state dict will be fully updated from checkpoint
+    if state_dict.keys() == updated_states.keys():
+        print("Fully updated state dict")
+    else:
+        print("Partially updated state dict")
+
+    # Now load the weights into the stage module
+    # We use `assign=True` because otherwise the properties of the tensors in
+    # the current module are preserved.
+    stage_module.load_state_dict(state_dict, assign=True)
+
diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
new file mode 100644
index 000000000..e3a7a98f6
--- /dev/null
+++ b/examples/llama/meta_init.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+"""
+This script shows how to create a llama model in "meta" device mode, partition
+it into pipeline stages, and materialize each stage module from Hugging Face
+checkpoints.
+
+Before running the script, please download the following files into the same
+directory as this script:
+- pytorch_model.bin.index.json
+- pytorch_model-00001-of-00002.bin
+- pytorch_model-00002-of-00002.bin
+
+Download link:
+https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
+
+How to run this script:
+$ python meta_init.py
+
+I haven't used a distributed runtime because I only want to showcase how to
+load each stage module. Feel free to modify the script to run in a distributed
+way by distributing the for loop at [Note 3].
+"""
+
+import os
+import torch
+from torch.distributed.pipelining import pipeline, SplitPoint
+from torch._subclasses.fake_tensor import FakeTensorMode
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from load_weights import load_weights
+
+# Grab the model in meta/fake mode
+fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
+
+with torch.device("meta"):
+    llama = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-chat-hf"
+    )
+
+print(llama)
+
+# Cast the model to FakeTensor with real device (from meta device) because
+# there is autocast code in llama. Autocast behaves according to the device
+# of the tensor, so we need to give it a real device instead of meta device.
+with fake_mode:
+    # [Note 1]: set device to "cuda" if you are using GPUs
+    llama.to_empty(device="cpu")
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+tokenizer.pad_token = tokenizer.eos_token
+prompts = (
+    "How do you", "I like to",
+)
+
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+real_ids = inputs["input_ids"]
+# The example input needs to be a FakeTensor too
+fake_ids = fake_mode.from_tensor(real_ids)
+
+# Beginning of distributed
+# [Note 2]: change world size here
+world_size = 4
+print(f"{world_size=}")
+
+# Cut model by equal number of layers per rank
+layers_per_rank = llama.config.num_hidden_layers // world_size
+print(f"layers_per_rank = {layers_per_rank}")
+split_spec = {
+    f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+    for i in range(1, world_size)
+}
+
+# Convert model into a pipeline
+pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)
+
+# Materialize each stage
+# [Note 3]: remove this for loop if you are running this script in a
+# distributed manner
+for rank in range(world_size):
+    stage_module = pipe.get_stage_module(rank)
+    print(f"Loading weights into stage {rank}")
+    load_weights(stage_module)
+

From 59589286909d7cb46a17dcf482fb008905044ebe Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Wed, 24 Jul 2024 02:12:18 -0700
Subject: [PATCH 2/4] Add kwargs to mute many outputs; change world_size to 2

---
 examples/llama/meta_init.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
index e3a7a98f6..629e9a60f 100644
--- a/examples/llama/meta_init.py
+++ b/examples/llama/meta_init.py
@@ -38,6 +38,7 @@
         "meta-llama/Llama-2-7b-chat-hf"
     )
 
+llama.eval()
 print(llama)
 
 # Cast the model to FakeTensor with real device (from meta device) because
@@ -60,7 +61,7 @@
 
 # Beginning of distributed
 # [Note 2]: change world size here
-world_size = 4
+world_size = 2
 print(f"{world_size=}")
 
 # Cut model by equal number of layers per rank
@@ -72,7 +73,12 @@
 }
 
 # Convert model into a pipeline
-pipe = pipeline(llama, mb_args=(fake_ids,), split_spec=split_spec)
+pipe = pipeline(
+    llama,
+    mb_args=(fake_ids,),
+    mb_kwargs={"output_attentions": False, "output_hidden_states": False, "use_cache": False},
+    split_spec=split_spec,
+)
 
 # Materialize each stage
 # [Note 3]: remove this for loop if you are running this script in a
@@ -81,4 +87,5 @@
     stage_module = pipe.get_stage_module(rank)
     print(f"Loading weights into stage {rank}")
     load_weights(stage_module)
+    stage_module.print_readable()
 
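The [Note 3] loop above materializes every stage in a single process, which is enough to demonstrate weight loading. For an actual multi-process run, a minimal sketch of distributing that loop is shown below. It is not taken from these patches: it assumes PyTorch 2.4+ where torch.distributed.pipelining provides `Pipe.build_stage` and `ScheduleGPipe`, a `torchrun --nproc-per-node=<world_size>` launch, CPU execution with the gloo backend, and it reuses `pipe`, `real_ids`, `world_size`, and `load_weights` from meta_init.py.

# Hypothetical replacement for the [Note 3] loop when launched via torchrun.
import torch
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe

dist.init_process_group(backend="gloo")  # CPU demo; use "nccl" and CUDA devices for GPUs
rank = dist.get_rank()
device = torch.device("cpu")

# Each rank materializes only its own stage module.
stage_module = pipe.get_stage_module(rank)
load_weights(stage_module)

# Wrap the materialized stage and attach a GPipe schedule.
stage = pipe.build_stage(rank, device)
schedule = ScheduleGPipe(stage, n_microbatches=1)

# The first rank feeds the token ids; the last rank receives the output.
if rank == 0:
    schedule.step(real_ids)
else:
    out = schedule.step()
    if rank == world_size - 1:
        print(out)

dist.destroy_process_group()

With one microbatch this is functionally a plain pipelined forward pass; raising n_microbatches (and the input batch size accordingly) is what actually overlaps the stages.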
From 1348e6528b5f4282a502bff715c5e77d5925c7b1 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 5 Aug 2024 05:49:47 -0700
Subject: [PATCH 3/4] Initialize buffers per init callbacks

---
 examples/llama/load_weights.py | 37 ++++++++++++++++++++++++++++++----
 examples/llama/meta_init.py    |  4 +++-
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/examples/llama/load_weights.py b/examples/llama/load_weights.py
index f79ef2740..a74388f80 100644
--- a/examples/llama/load_weights.py
+++ b/examples/llama/load_weights.py
@@ -3,7 +3,7 @@
 import json
 import torch
 
-from typing import Optional
+from typing import Callable, Dict, Optional
 
 
 def load_weights(
@@ -37,14 +37,19 @@ def load_weights(
     # files because the stage module is a partition of the full model.
     needed_files = set()
     for param in state_dict.keys():
-        file = weight_map[param]
-        needed_files.add(file)
+        # The file a param is saved in
+        file = weight_map.setdefault(param, None)
+        if file:
+            needed_files.add(file)
 
     # Now we load the needed binary files
     for file in needed_files:
         checkpoint = torch.load(file, weights_only=True)
         for param in state_dict.keys():
-            if weight_map[param] == file:
+            file_having_param = weight_map[param]
+            if file_having_param is None:
+                print(f"Cannot find checkpoint file for {param}, skipping")
+            elif file_having_param == file:
                 state_dict[param] = checkpoint[param]
                 updated_states.setdefault(param, None)
 
@@ -59,3 +64,27 @@ def load_weights(
     # the current module are preserved.
     stage_module.load_state_dict(state_dict, assign=True)
 
+
+def init_buffers(
+    stage_module: torch.nn.Module,
+    device: torch.device,
+    init_callbacks: Dict[str, Callable],
+):
+    """
+    Initialize buffers of `stage_module` per the callbacks in `init_callbacks`.
+    `init_callbacks` is a dictionary from a buffer's FQN to its init function.
+    """
+    for name, buf in stage_module.named_buffers():
+        if name in init_callbacks:
+            cb = init_callbacks[name]
+            buf_val = cb(device)
+            # Find the parent module
+            splits = name.split(".")
+            mod = stage_module
+            for atom in splits[:-1]:
+                mod = getattr(mod, atom)
+            mod.register_buffer(
+                splits[-1], buf_val, persistent=False,
+            )
+            print(f"Initialized buffer {name}")
+
diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
index 629e9a60f..6f2b0197d 100644
--- a/examples/llama/meta_init.py
+++ b/examples/llama/meta_init.py
@@ -28,7 +28,7 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from load_weights import load_weights
+from load_weights import load_weights, init_buffers
 
 # Grab the model in meta/fake mode
 fake_mode = FakeTensorMode(allow_non_fake_inputs=True)
@@ -87,5 +87,7 @@
     stage_module = pipe.get_stage_module(rank)
     print(f"Loading weights into stage {rank}")
     load_weights(stage_module)
+    if hasattr(llama, "buf_init_callbacks"):
+        init_buffers(stage_module, "cpu", llama.buf_init_callbacks)
     stage_module.print_readable()
 

From d09df81442b16deb86c148d647bcd483443a2429 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Tue, 6 Aug 2024 20:03:17 -0700
Subject: [PATCH 4/4] Add dtype option to init_buffers

---
 examples/llama/load_weights.py | 7 +++++--
 examples/llama/meta_init.py    | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/examples/llama/load_weights.py b/examples/llama/load_weights.py
index a74388f80..0c5bd3fee 100644
--- a/examples/llama/load_weights.py
+++ b/examples/llama/load_weights.py
@@ -67,8 +67,9 @@ def load_weights(
 
 def init_buffers(
     stage_module: torch.nn.Module,
-    device: torch.device,
     init_callbacks: Dict[str, Callable],
+    device: torch.device,
+    dtype: Optional[torch.dtype] = None,
 ):
     """
     Initialize buffers of `stage_module` per the callbacks in `init_callbacks`.
@@ -78,6 +79,8 @@ def init_buffers(
         if name in init_callbacks:
             cb = init_callbacks[name]
             buf_val = cb(device)
+            if dtype:
+                buf_val = buf_val.to(dtype)
             # Find the parent module
             splits = name.split(".")
             mod = stage_module
@@ -86,5 +89,5 @@ def init_buffers(
             mod.register_buffer(
                 splits[-1], buf_val, persistent=False,
             )
-            print(f"Initialized buffer {name}")
+            print(f"Initialized buffer {name}, {buf_val.dtype}, {buf_val.device}")
 
diff --git a/examples/llama/meta_init.py b/examples/llama/meta_init.py
index 6f2b0197d..605d8a958 100644
--- a/examples/llama/meta_init.py
+++ b/examples/llama/meta_init.py
@@ -88,6 +88,6 @@
     print(f"Loading weights into stage {rank}")
     load_weights(stage_module)
     if hasattr(llama, "buf_init_callbacks"):
-        init_buffers(stage_module, "cpu", llama.buf_init_callbacks)
+        init_buffers(stage_module, llama.buf_init_callbacks, "cpu", torch.float16)
     stage_module.print_readable()
 
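Patches 3 and 4 materialize non-persistent buffers through a `buf_init_callbacks` mapping that meta_init.py looks up on the model with `hasattr`; a stock Hugging Face `LlamaForCausalLM` does not define that attribute, so buffer initialization is skipped unless the caller attaches one. Below is a hypothetical sketch of building such a mapping for the rotary-embedding `inv_freq` buffers, which the Llama-2-era transformers code registers with `persistent=False` (so they are absent from the checkpoint and untouched by `load_weights`). The buffer FQNs, the RoPE base of 10000, and the expectation that stage modules keep the original buffer FQNs (as they do for parameters) are all assumptions, not something these patches guarantee.

# Hypothetical glue for meta_init.py: attach inv_freq init callbacks to the
# model before the materialization loop. FQNs assume the transformers Llama-2
# layout "model.layers.<i>.self_attn.rotary_emb.inv_freq" -- adjust if your
# transformers version names them differently.
def make_inv_freq_init(head_dim, base=10000.0):
    def init(device):
        # Standard RoPE inverse frequencies: 1 / base^(2k / head_dim)
        idx = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        return 1.0 / (base ** (idx / head_dim))
    return init

head_dim = llama.config.hidden_size // llama.config.num_attention_heads
llama.buf_init_callbacks = {
    f"model.layers.{i}.self_attn.rotary_emb.inv_freq": make_inv_freq_init(head_dim)
    for i in range(llama.config.num_hidden_layers)
}

Note that the patch-4 call then casts these values to torch.float16; whether reduced precision is acceptable for `inv_freq` is a separate judgment call.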