Bringing liger kernels back
TJ-Solergibert committed Sep 26, 2024
1 parent 3969aa2 commit df3ef9d
Showing 3 changed files with 2 additions and 237 deletions.
135 changes: 0 additions & 135 deletions src/nanotron/kernels/swiglu.py

This file was deleted.
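
The deleted kernel source is not shown in this view. Judging only from its name, src/nanotron/kernels/swiglu.py presumably provided a fused SwiGLU implementation. For reference, a minimal unfused SwiGLU in plain PyTorch, a sketch of the operation itself rather than the removed kernel:

import torch
import torch.nn.functional as F

def swiglu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU as used in LLaMA-style MLPs: SiLU(gate) multiplied elementwise with up.
    # A fused kernel would compute this (and its backward) in one pass instead of
    # materializing the intermediate SiLU output.
    return F.silu(gate) * up

gate, up = torch.randn(2, 16, 512), torch.randn(2, 16, 512)
out = swiglu(gate, up)  # shape [2, 16, 512]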

65 changes: 0 additions & 65 deletions src/nanotron/kernels/utils.py

This file was deleted.

39 changes: 2 additions & 37 deletions src/nanotron/models/llama_sft.py
@@ -25,6 +25,7 @@
from nanotron.config import Config, LlamaConfig, ParallelismArgs
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.generation.generate_store import AttachableStore
+ from nanotron.kernels.rope import liger_rotary_pos_emb
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
from nanotron.nn.activations import ACT2FN
@@ -111,39 +112,6 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)


- # NOTE(tj.solergibert) FlashAttention RoPEs are faster (triton), but currently they don't support position_ids
- def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
-     """Applies Rotary Position Embedding to the query and key tensors.
-     Args:
-         q (torch.Tensor): The query tensor.
-         k (torch.Tensor): The key tensor.
-         cos (torch.Tensor): The cosine part of the rotary embedding.
-         sin (torch.Tensor): The sine part of the rotary embedding.
-         unsqueeze_dim (int, *optional*, defaults to 1):
-             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-     Returns:
-         tuple (torch.Tensor) comprising of the query and key tensors rotated using the Rotary Position Embedding.
-     """
-     cos = cos.unsqueeze(unsqueeze_dim)
-     sin = sin.unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed


def prepare_varlen_args(position_ids):
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
@@ -336,11 +304,8 @@ def forward(
.contiguous()
) # [3, batch_size, seq_length, n_local_q_heads, d_qk]

- # TODO(tj.solergibert) Apply RoPE embeddings WITHOUT too many transpose...
- query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)
# Apply RoPE
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
- query_states, key_states = query_states.transpose(1, 2), key_states.transpose(1, 2)
+ query_states, key_states = liger_rotary_pos_emb(query_states, key_states, cos, sin)

# Prepare varlen args
cu_seqlens, max_seqlen_in_batch = prepare_varlen_args(position_ids)
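For context on the RoPE change above: the removed reference apply_rotary_pos_emb required the surrounding transposes only because it was called with the default unsqueeze_dim=1, which broadcasts cos/sin against a [batch, heads, seq_len, head_dim] layout; its own docstring points out that unsqueeze_dim=2 covers the [batch, seq_len, heads, head_dim] layout this model keeps its tensors in. The sketch below only verifies that layout equivalence in plain PyTorch with stand-in values; it is not the fused kernel, and the signature and layout expected by nanotron's liger_rotary_pos_emb wrapper are assumed here to match the call shown in the diff.

import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

batch, seq, heads, d = 2, 8, 4, 16
q, k = torch.randn(batch, seq, heads, d), torch.randn(batch, seq, heads, d)
cos, sin = torch.randn(batch, seq, d), torch.randn(batch, seq, d)  # stand-in values; only the broadcasting matters here

# Removed path: transpose to [batch, heads, seq, d], rotate with unsqueeze_dim=1, transpose back.
q_ref, k_ref = apply_rotary_pos_emb(q.transpose(1, 2), k.transpose(1, 2), cos, sin)
q_ref, k_ref = q_ref.transpose(1, 2), k_ref.transpose(1, 2)

# Same rotation applied directly in [batch, seq, heads, d] with unsqueeze_dim=2, no transposes.
q_new, k_new = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)

torch.testing.assert_close(q_ref, q_new)
torch.testing.assert_close(k_ref, k_new)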

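The body of prepare_varlen_args is truncated in this view. As a hypothetical completion, following the usual Flash-Attention varlen recipe for packed sequences (every position_id equal to 0 marks the start of a new sequence), rather than nanotron's actual code:

import torch

def prepare_varlen_args_sketch(position_ids: torch.Tensor):
    # Hypothetical continuation of the truncated function above (an assumption, not the repo's code).
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
    # Cumulative sequence starts: every index where position_ids == 0, plus the total token count.
    cu_seqlens = torch.cat(
        (indices_q[position_ids == 0], torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32))
    )
    # Longest packed sequence in the batch.
    max_seqlen_in_batch = int(position_ids.max()) + 1
    return cu_seqlens, max_seqlen_in_batch

# Two packed sequences of lengths 3 and 5 in one row:
pos = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
print(prepare_varlen_args_sketch(pos))  # cu_seqlens = [0, 3, 8], max_seqlen_in_batch = 5

These are the kinds of values typically passed to flash_attn_varlen_func as cu_seqlens_q/cu_seqlens_k and max_seqlen_q/max_seqlen_k.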