
Meta init llama then pipeline then materialize #1135

Open · wants to merge 4 commits into main
Conversation

@kwen2501 (Contributor) commented Jul 23, 2024

Models can be big, so we need to:

  • create the model's "skeleton" on the meta device,
  • partition it so that each partition can fit on its device, and
  • materialize each partition (see the sketch below).

This is a demo based on the Llama-2-7b-chat-hf model and its checkpoint on the Hugging Face Model Hub.
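
Below is a minimal sketch of those three steps, assuming the `torch.distributed.pipelining` API in the torch 2.5 nightly; the split point and variable names here are illustrative, not necessarily the ones hard-coded in `meta_init.py`:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from torch.distributed.pipelining import pipeline, SplitPoint

model_id = "meta-llama/Llama-2-7b-chat-hf"

# 1) Build the model "skeleton" on the meta device: parameter shapes and
#    dtypes are recorded, but no real memory is allocated.
#    (Loading the config needs access to the gated repo or a local config.json.)
config = AutoConfig.from_pretrained(model_id)
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config)

# 2) Partition the model into pipeline stages by tracing it with example
#    micro-batch inputs and cutting at a chosen decoder layer (layer 16
#    splits the 32 layers evenly across 2 stages).
example_input = torch.randint(0, config.vocab_size, (2, 4), dtype=torch.int64)
pipe = pipeline(
    model,
    mb_args=(example_input,),
    split_spec={"model.layers.16": SplitPoint.BEGINNING},
)

# 3) Materialize one partition: replace its meta tensors with real (empty)
#    storage, ready to be filled from the checkpoint shards.
stage_mod = pipe.get_stage_module(0)
stage_mod.to_empty(device="cpu")
```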

Before running the script, please download the following files into the same directory as this script:

  • pytorch_model.bin.index.json
  • pytorch_model-00001-of-00002.bin
  • pytorch_model-00002-of-00002.bin

Download link:
https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/tree/main
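
If you'd rather fetch them programmatically, one option (not part of this script) is `huggingface_hub`; note that Llama-2 is a gated repo, so you need to accept the license on the model page and log in (e.g. `huggingface-cli login`) first:

```python
from huggingface_hub import hf_hub_download

repo_id = "meta-llama/Llama-2-7b-chat-hf"
for filename in (
    "pytorch_model.bin.index.json",
    "pytorch_model-00001-of-00002.bin",
    "pytorch_model-00002-of-00002.bin",
):
    # Downloads into the current directory, next to meta_init.py.
    hf_hub_download(repo_id=repo_id, filename=filename, local_dir=".")
```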

Your directory should look like this:
[Screenshot of the directory listing]

How to run this script:
$ python meta_init.py

I haven't used a distributed runtime because I only have a MacBook at hand, but I tried to show how to load each stage module from the HF checkpoints. Feel free to modify the script to run in a distributed way by distributing the for loop at [Note 3] (a rough sketch of what that could look like is below).
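
For reference, here is a minimal sketch of distributing that loop, assuming the script is launched with `torchrun` and that `pipe` is the traced `Pipe` object from earlier; `load_hf_weights_into` is a hypothetical stand-in for the checkpoint-loading logic at [Note 3]:

```python
import torch
import torch.distributed as dist

# Launched with e.g. `torchrun --nproc-per-node 2 meta_init.py`,
# so each rank owns exactly one pipeline stage.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
device = torch.device("cpu")

# Instead of one process looping over all stages, each rank pulls out
# only its own stage module from the traced Pipe (still on meta)...
stage_mod = pipe.get_stage_module(rank)

# ...allocates real (empty) storage for its parameters and buffers...
stage_mod.to_empty(device=device)

# ...and fills them from the HF shards. `load_hf_weights_into` is a
# hypothetical helper standing in for the index-json + torch.load logic
# at [Note 3] in the script.
load_hf_weights_into(stage_mod, "pytorch_model.bin.index.json")
```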

My torch version:
torch 2.5.0.dev20240722
I installed it with:
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu

Cc: @lessw2020 @muellerzr @SunMarc @H-Huang @wconstab @LucasLLC

@kwen2501 (Contributor, Author) commented Jul 23, 2024

Run logs:

(base) kw2501@kw2501-mbp llama % python meta_init.py

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
world_size=2
layers_per_rank = 16
Loading weights into stage 0
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, input_ids: "i64[2, 4]"):
        # No stacktrace found for following nodes
        model = self.model(input_ids);  input_ids = None
        getitem: "i64[1, 4]" = model[0]
        getitem_1: "f32[4, 5]" = model[1]
        getitem_2: "f32[2, 4, 4096]" = model[2];  model = None
        return (getitem_2, getitem, getitem_1)
        
Loading weights into stage 1
Fully updated state dict
class GraphModule(torch.nn.Module):
    def forward(self, add_95: "f32[2, 4, 4096]", unsqueeze: "i64[1, 4]", mul: "f32[4, 5]"):
        # No stacktrace found for following nodes
        model: "f32[2, 4, 4096]" = self.model(mul, unsqueeze, add_95);  mul = unsqueeze = add_95 = None
        lm_head: "f32[2, 4, 32000]" = self.lm_head(model);  model = None
        
         # File: /opt/anaconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1194 in forward, code: logits = logits.float()
        _to_copy_default_162: "f32[2, 4, 32000]" = torch.ops.aten._to_copy.default(lm_head, dtype = torch.float32);  lm_head = None
        return _to_copy_default_162

@wconstab (Contributor) left a comment:

Looks helpful!
