Meta init llama then pipeline then materialize #1135

Open · wants to merge 4 commits into base: main
93 changes: 93 additions & 0 deletions examples/llama/load_weights.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import json
import torch

from typing import Callable, Dict, Optional


def load_weights(
stage_module: torch.nn.Module,
    weight_index_file: str = "pytorch_model.bin.index.json",
):
"""
Load weights from Hugging Face checkpoints into a stage module.

    This is a utility for Hugging Face Model Hub checkpoints that come with an
    index file and multiple binary files. The index file indicates which binary
    file each parameter is saved in. An example can be found at:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main

    Please download the following files into the same directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin
"""

state_dict = stage_module.state_dict()
updated_states = dict()

    # Get the weight map -- a map from parameter name to the file it is saved in
    with open(weight_index_file) as f:
        js = json.load(f)
    weight_map = js["weight_map"]
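    # For reference, the index file has roughly this shape (entries here are
    # illustrative, not exact):
    # {
    #   "metadata": {"total_size": ...},
    #   "weight_map": {
    #     "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
    #     "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
    #     ...
    #   }
    # }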

    # Figure out the set of binary files we need to open in order to fill the
# state dict of the stage module. It will be a subset of all the binary
# files because the stage module is a partition of the full model.
needed_files = set()
for param in state_dict.keys():
        # The file this param is saved in (None if the param is not in the checkpoint)
        file = weight_map.get(param)
if file:
needed_files.add(file)
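    # At this point `needed_files` is typically a strict subset of the shards;
    # e.g. the first stage of a 2-way split might only need
    # {"pytorch_model-00001-of-00002.bin"} (illustrative; depends on the split).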

# Now we load the needed binary files
for file in needed_files:
checkpoint = torch.load(file, weights_only=True)
for param in state_dict.keys():
            file_having_param = weight_map.get(param)
if file_having_param is None:
print(f"Cannot find checkpoint file for {param}, skipping")
elif file_having_param == file:
state_dict[param] = checkpoint[param]
updated_states.setdefault(param, None)

# Check if the module's state dict will be fully updated from checkpoint
if state_dict.keys() == updated_states.keys():
print("Fully updated state dict")
else:
print("Partially updated state dict")

    # Now load the weights into the stage module.
    # We use `assign=True` because otherwise the properties of the existing
    # tensors in the stage module (e.g. their meta device) would be preserved;
    # `assign=True` replaces them with the loaded tensors instead.
stage_module.load_state_dict(state_dict, assign=True)


def init_buffers(
stage_module: torch.nn.Module,
init_callbacks: Dict[str, Callable],
device: torch.device,
dtype: Optional[torch.dtype] = None,
):
"""
Initialize buffers of `stage_module` per the callback in `init_callbacks`.
`init_callbacks` is a dictionary from a buffer's FQN to its init function.
"""
for name, buf in stage_module.named_buffers():
if name in init_callbacks:
cb = init_callbacks[name]
buf_val = cb(device)
if dtype:
buf_val = buf_val.to(dtype)
# Find the parent module
splits = name.split(".")
mod = stage_module
            for atom in splits[:-1]:
mod = getattr(mod, atom)
mod.register_buffer(
splits[-1], buf_val, persistent=False,
)
print(f"Initialized buffer {name}, {buf_val.dtype}, {buf_val.device}")

93 changes: 93 additions & 0 deletions examples/llama/meta_init.py
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

"""
This script shows how to create a llama model on the "meta" device, partition it
into pipeline stages, and materialize each stage module from Hugging Face
checkpoints.

Before running the script, please download the following files into the same
directory as this script:
- pytorch_model.bin.index.json
- pytorch_model-00001-of-00002.bin
- pytorch_model-00002-of-00002.bin

Download link:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main

How to run this script:
$ python meta_init.py

I haven't used a distributed runtime because I only want to showcase how to
load each stage module. Feel free to modify the script to run in a distributed
way by distributing the for loop at [Note 3] (a rough sketch follows the script).
"""

import os
import torch
from torch.distributed.pipelining import pipeline, SplitPoint
from torch._subclasses.fake_tensor import FakeTensorMode
from transformers import AutoModelForCausalLM, AutoTokenizer

from load_weights import load_weights, init_buffers

# Grab the model in meta/fake mode
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

with torch.device("meta"):
llama = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf"
)

llama.eval()
print(llama)

# Convert the model from meta tensors to FakeTensors on a real device, because
# llama contains autocast code that dispatches based on the device of its
# tensors, so it needs a real device rather than the meta device.
with fake_mode:
# [Note 1]: set device to "cuda" if you are using GPUs
llama.to_empty(device="cpu")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token = tokenizer.eos_token
prompts = (
"How do you", "I like to",
)

inputs = tokenizer(prompts, return_tensors="pt", padding=True)
real_ids = inputs["input_ids"]
# The example input needs to be a FakeTensor too
fake_ids = fake_mode.from_tensor(real_ids)

# Beginning of distributed
# [Note 2]: change world size here
world_size = 2
print(f"{world_size=}")

# Cut the model so that each rank gets an equal number of layers
layers_per_rank = llama.config.num_hidden_layers // world_size
print(f"layers_per_rank = {layers_per_rank}")
split_spec = {
f"model.layers.{i * layers_per_rank}": SplitPoint.BEGINNING
for i in range(1, world_size)
}
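# For example, Llama-2-7b has 32 hidden layers, so with world_size = 2 the
# split spec above is {"model.layers.16": SplitPoint.BEGINNING}, i.e. the
# second stage begins at layer 16.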

# Convert model into a pipeline
pipe = pipeline(
llama,
mb_args=(fake_ids,),
    mb_kwargs={
        "output_attentions": False,
        "output_hidden_states": False,
        "use_cache": False,
    },
split_spec=split_spec,
)

# Materialize each stage
# [Note 3]: remove this for loop if you are running this script in a
# distributed manner
for rank in range(world_size):
stage_module = pipe.get_stage_module(rank)
print(f"Loading weights into stage {rank}")
load_weights(stage_module)
if hasattr(llama, "buf_init_callbacks"):
init_buffers(stage_module, llama.buf_init_callbacks, "cpu", torch.float16)
stage_module.print_readable()
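To run the materialization in a truly distributed way, as hinted at in [Note 3], the for loop above could be replaced with per-rank logic along these lines. This is an untested sketch under the assumption that `pipe.build_stage` and `ScheduleGPipe` from torch.distributed.pipelining are used for execution, launched with something like `torchrun --nproc-per-node 2 meta_init.py`:

    import torch.distributed as dist
    from torch.distributed.pipelining import ScheduleGPipe

    dist.init_process_group()
    rank = dist.get_rank()

    # Each rank materializes only its own stage
    stage_module = pipe.get_stage_module(rank)
    load_weights(stage_module)
    if hasattr(llama, "buf_init_callbacks"):
        init_buffers(stage_module, llama.buf_init_callbacks, "cpu", torch.float16)

    # Wrap the stage and run one pipelined forward pass
    stage = pipe.build_stage(rank, device=torch.device("cpu"))
    schedule = ScheduleGPipe(stage, n_microbatches=1)
    if rank == 0:
        schedule.step(real_ids)
    else:
        output = schedule.step()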
