Modify reference code to support CPU inference
cglagovichTT committed Oct 18, 2024
Commit 7de6a77 (parent: c0be8c0)
Showing 5 changed files with 52 additions and 65 deletions.
36 changes: 6 additions & 30 deletions models/llama3/reference_impl/generation.py
@@ -24,11 +24,6 @@
 
 import torch
 import torch.nn.functional as F
-from fairscale.nn.model_parallel.initialize import (
-    get_model_parallel_rank,
-    initialize_model_parallel,
-    model_parallel_is_initialized,
-)
 from termcolor import cprint
 
 from ..api.args import ModelArgs
@@ -99,30 +94,15 @@ def build(
         and loads the pre-trained model and tokenizer.
         """
 
-        if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
-
-        if not model_parallel_is_initialized():
-            if model_parallel_size is None:
-                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
-            initialize_model_parallel(model_parallel_size)
-
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
 
         torch.manual_seed(seed)
 
-        if local_rank > 0:
-            sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-        assert model_parallel_size == len(
-            checkpoints
-        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
-        ckpt_path = checkpoints[get_model_parallel_rank()]
+        ckpt_path = checkpoints[0]
         checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
@@ -134,15 +114,11 @@
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        if torch.cuda.is_bf16_supported():
-            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
-        else:
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
         if model_args.vision_chunk_size > 0:
             from .multimodal.model import CrossAttentionTransformer
 
             model = CrossAttentionTransformer(model_args)
-            model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+            model.setup_cache(model_args.max_batch_size, torch.float32)
         else:
             model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=True)
@@ -209,14 +185,14 @@ def generate(
         )
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
         if logprobs:
             token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz)
         input_text_mask = tokens != pad_id
 
         if echo:
@@ -233,7 +209,7 @@
         for cur_pos in range(min_prompt_len, total_len):
             if is_vision:
                 position_ids = torch.arange(
-                    prev_pos, cur_pos, dtype=torch.long, device="cuda"
+                    prev_pos, cur_pos, dtype=torch.long
                 )
                 text_only_inference = model_input.vision is None
                 logits = self.model.forward(
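The net effect in generation.py is that build() no longer assumes a torch.distributed/NCCL world or a CUDA device: it loads the single consolidated checkpoint (checkpoints[0]) and creates tokens, masks, and position ids on the default (CPU) device in float32. The sketch below is my own illustration of the device/dtype choice this enables, not code from this commit; pick_device_and_dtype is a hypothetical helper name.

# Hypothetical helper, not part of this commit: choose a device and dtype so the
# same build()/generate() path can run on CPU or on a single CUDA GPU.
import torch

def pick_device_and_dtype():
    if torch.cuda.is_available():
        # Prefer bf16 where the GPU supports it, otherwise fp16.
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        return torch.device("cuda"), dtype
    # The CPU path in this commit runs in float32.
    return torch.device("cpu"), torch.float32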
30 changes: 20 additions & 10 deletions models/llama3/reference_impl/model.py
@@ -11,14 +11,8 @@
 import math
 from typing import Optional, Tuple
 
-import fairscale.nn.model_parallel.initialize as fs_init
 import torch
 import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
-    ColumnParallelLinear,
-    RowParallelLinear,
-    VocabParallelEmbedding,
-)
 from torch import nn
 
 from ..api import ModelArgs
@@ -27,7 +21,23 @@
 # dependencies. These dependencies are not part of the default dependencies
 # (requirements.txt) of the `llama-models` package.
 
-
+class FakeParallelLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        gather_output=False,
+        input_is_parallel=False,
+        init_method=None):
+        super().__init__(in_features, out_features, bias=bias)
+
+ColumnParallelLinear = RowParallelLinear = FakeParallelLinear
+
+class VocabParallelEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, init_method=None):
+        super().__init__(num_embeddings, embedding_dim)
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -116,7 +126,7 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         self.n_local_heads = args.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
@@ -158,15 +168,15 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        )
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        )
 
     def forward(
         self,
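The FakeParallelLinear and VocabParallelEmbedding shims keep the rest of model.py unchanged: they accept the fairscale-style constructor keywords (gather_output, input_is_parallel, init_method) but behave as plain nn.Linear / nn.Embedding layers, and with model_parallel_size = 1 the "local" head counts equal the full head counts, so no weight sharding is implied and the KV caches stay on the default CPU device. A small standalone check of the shim idea follows; it is my own example, not taken from the commit.

# Standalone sketch of the shim pattern above (illustrative only).
import torch
from torch import nn

class FakeParallelLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True,
                 gather_output=False, input_is_parallel=False, init_method=None):
        # The extra keywords only mirror fairscale's ColumnParallelLinear /
        # RowParallelLinear signatures; in a single process they are ignored.
        super().__init__(in_features, out_features, bias=bias)

wq = FakeParallelLinear(512, 512, bias=False, gather_output=False, init_method=lambda x: x)
out = wq(torch.randn(2, 8, 512))
print(out.shape)  # torch.Size([2, 8, 512])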
51 changes: 26 additions & 25 deletions models/llama3/reference_impl/multimodal/model.py
@@ -13,20 +13,13 @@
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-import fairscale.nn.model_parallel.initialize as fs_init
 
 import torch
 import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
-    ColumnParallelLinear,
-    RowParallelLinear,
-    VocabParallelEmbedding,
-)
 
 from PIL import Image as PIL_Image
 
 from torch import nn, Tensor
-from torch.distributed import _functional_collectives as funcol
 
 from ..model import apply_rotary_emb, ModelArgs, precompute_freqs_cis, RMSNorm
 
@@ -42,6 +35,23 @@
 from .utils import get_negative_inf_value, to_2tuple
 
 
+class FakeParallelLinear(nn.Linear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        gather_output=False,
+        input_is_parallel=False,
+        init_method=None):
+        super().__init__(in_features, out_features, bias=bias)
+
+ColumnParallelLinear = RowParallelLinear = FakeParallelLinear
+
+class VocabParallelEmbedding(nn.Embedding):
+    def __init__(self, num_embeddings, embedding_dim, init_method=None):
+        super().__init__(num_embeddings, embedding_dim)
+
 logger = logging.getLogger(__name__)
 MP_SCALE = 8
 
@@ -131,7 +141,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self._unfold(x)
         x = x.permute(0, 2, 1)
         x = F.linear(x, self._linear.weight)
-        x = gather_from_tensor_model_parallel_region(x)
         return x
 
 
@@ -166,7 +175,6 @@ def forward(self, x):
         hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias)
         hidden = self.non_linearity(hidden)
         hidden = F.linear(hidden, self.c_proj.weight)
-        hidden = reduce_from_tensor_model_parallel_region(hidden)
         hidden += self.c_proj.bias
         return hidden
 
@@ -179,7 +187,7 @@ def __init__(
         n_heads,
     ):
         super().__init__()
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         qkvo_replication = 1
         if model_parallel_size > 16:
             qkvo_replication = model_parallel_size // 8
@@ -250,7 +258,6 @@ def forward(
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
 
         out = F.linear(attn_output, self.wo.weight)
-        out = reduce_from_tensor_model_parallel_region(out)
         out = out / self.qkvo_replication
         return out
 
@@ -561,7 +568,7 @@ def __init__(self, args: ModelArgs):
             cache_v (torch.Tensor): Cached values for attention.
         """
         super().__init__()
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         replication_factor = 1
         if model_parallel_size > 8:
             replication_factor = model_parallel_size // MP_SCALE
@@ -671,7 +678,6 @@ def forward(
         attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
 
         out = F.linear(attn_output, self.wo.weight)
-        out = reduce_from_tensor_model_parallel_region(out)
         return out
 
 
@@ -717,7 +723,6 @@ def forward(self, x):
         x1 = F.silu(x1)
         x_in = x1 * x3
         out = F.linear(x_in, self.w2.weight)
-        out = reduce_from_tensor_model_parallel_region(out)
         return out
 
 
@@ -874,7 +879,7 @@ def __init__(
         norm_eps: float,
     ):
         super().__init__()
-        self.model_parallel_size = fs_init.get_model_parallel_world_size()
+        self.model_parallel_size = 1
         replication_factor = 1
         if self.model_parallel_size > 8:
             replication_factor = self.model_parallel_size // MP_SCALE
@@ -983,7 +988,6 @@ def forward(
         output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1)
 
         out = F.linear(output, self.wo.weight)
-        out = reduce_from_tensor_model_parallel_region(out)
         return out
 
 
@@ -1113,13 +1117,12 @@ def forward(
         # aspect_ratios: (B, T)
         # h: (B, T, D)
         vision_tokens = self.vision_encoder(
-            images.to(dtype=torch.bfloat16), aspect_ratios
+            images.to(dtype=torch.float32), aspect_ratios
         )
 
         vision_tokens = F.linear(
             vision_tokens, self.vision_projection.weight, self.vision_projection.bias
         )
-        vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens)
         return vision_tokens
 
 
@@ -1128,7 +1131,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
 
     def __init__(self, args: ModelArgs) -> None:
         super().__init__()
-        self.model_parallel_size = fs_init.get_model_parallel_world_size()
+        self.model_parallel_size = 1
         assert args.vocab_size > 0
         self.vocab_size = args.vocab_size
         self.n_layers = args.n_layers
@@ -1158,7 +1161,7 @@ def __init__(self, args: ModelArgs) -> None:
             args.vision_num_cross_attention_layers
         )
         self.learnable_embedding = VocabParallelEmbedding(
-            max(fs_init.get_model_parallel_world_size(), 8),
+            8,
             args.dim,
             init_method=lambda x: x,
         )
@@ -1270,10 +1273,9 @@ def forward(
         h = self.norm(h)
 
         output = F.linear(h, self.output.weight)
-        output = gather_from_tensor_model_parallel_region(output)
         return output.float()
 
-    def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
+    def setup_cache(self, max_batch_size: int, dtype=torch.float32):
         # Set up the text kv caches
         device = next(self.parameters()).device
         ones = torch.ones(
@@ -1407,7 +1409,6 @@ def compute_vision_tokens_masks(
         else:
             vision_tokens = self.vision_model(stacked_images, aspect_ratios)
 
-        vision_tokens = vision_tokens.to("cuda")
 
         bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
         xattn_caches = torch.stack(
@@ -1428,7 +1429,7 @@
         cross_attention_masks, full_text_row_masked_out_mask = (
             self.text_model._get_xattn_mask(
                 num_tokens=total_len,
-                text_device="cuda",
+                text_device="cpu",
                 text_dtype=next(self.text_model.parameters()).dtype,
                 vision_tokens=vision_tokens,
                 cross_attention_masks=padded_masks,
@@ -1495,7 +1496,7 @@ def _pad_masks(
     total_len: int,
     max_num_chunks: int,
 ) -> torch.Tensor:
-    dtype = torch.bfloat16
+    dtype = torch.float32
     inf_value = get_negative_inf_value(dtype)
 
     bsz = len(all_masks)
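Dropping the reduce_from_/gather_from_tensor_model_parallel_region calls is safe here because the model-parallel world size is 1: each "parallel" linear owns its full weight, so the plain F.linear output is already complete and the collective would only be an identity. The snippet below is my own sketch (not from the commit) of why the all-reduce matters only when the weight is actually sharded.

# Row-parallel linear in miniature: the all-reduce sums partial products from
# each shard. With a single shard there is nothing to sum, so removing the
# reduce does not change the result.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 16)
w = torch.randn(32, 16)

# Two-shard case: split the input features and sum the partial outputs
# (this sum is what the all-reduce performs across ranks).
partial_sum = F.linear(x[:, :8], w[:, :8]) + F.linear(x[:, 8:], w[:, 8:])

# One-shard case (the CPU path in this commit): the plain linear is the answer.
full = F.linear(x, w)
print(torch.allclose(partial_sum, full, atol=1e-5))  # True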
Binary file added models/scripts/resources/clutter.jpeg
Binary file added models/scripts/resources/ocr_image.jpeg
