Compatibility with llama.py checkpoints
TJ-Solergibert committed Sep 16, 2024
1 parent a185c50 commit a7051d1
Showing 1 changed file with 11 additions and 14 deletions.
src/nanotron/models/llama_sft.py: 11 additions & 14 deletions
@@ -26,7 +26,6 @@
 from nanotron.config.models_config import RandomInit, SpectralMupInit
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.kernels.rope import liger_rotary_pos_emb
-from nanotron.kernels.swiglu import LigerSiLUMulFunction
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
@@ -157,24 +156,19 @@ def __init__(
             parallel_config.tp_linear_async_communication if parallel_config is not None else False
         )
 
-        self.gate_proj = TensorParallelColumnLinear(
-            config.hidden_size,
-            config.intermediate_size,
-            pg=tp_pg,
-            mode=tp_mode,
-            bias=False,
-            async_communication=tp_linear_async_communication,
+        gate_up_contiguous_chunks = (
+            config.intermediate_size,  # shape of gate_linear
+            config.intermediate_size,  # shape of up_linear
         )
-
-        self.up_proj = TensorParallelColumnLinear(
+        self.gate_up_proj = TensorParallelColumnLinear(
             config.hidden_size,
-            config.intermediate_size,
+            2 * config.intermediate_size,
             pg=tp_pg,
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
+            contiguous_chunks=gate_up_contiguous_chunks,
         )
-
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
             config.hidden_size,
@@ -183,10 +177,13 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
+        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
+        self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
-
-        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(hidden_states), self.up_proj(hidden_states)))
+        merged_states = self.gate_up_proj(hidden_states)
+        hidden_states = self.down_proj(self.split_silu_mul(merged_states))
+        return hidden_states
 
 
 class CausalSelfAttention(nn.Module, AttachableStore):
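For context, here is a minimal standalone sketch of what the fused projection in this diff computes. Plain torch.nn.Linear stands in for nanotron's TensorParallelColumnLinear/TensorParallelRowLinear, and the chunk-then-silu(gate)*up step mirrors what GLUActivation does to the merged output; the module below is an illustration, not the repository's implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLUMLP(nn.Module):
    """Illustrative stand-in for the MLP after this commit: one fused
    gate+up projection, a SiLU-gated product, then a down projection."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Single column-parallel projection in the real model; its output
        # holds the gate half and the up half back to back.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        merged_states = self.gate_up_proj(hidden_states)
        # Split the merged tensor into its gate and up halves and gate with SiLU.
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return self.down_proj(F.silu(gate) * up)

x = torch.randn(5, 2, 64)                      # [seq_length, batch_size, hidden_dim]
mlp = FusedSwiGLUMLP(hidden_size=64, intermediate_size=128)
print(mlp(x).shape)                            # torch.Size([5, 2, 64])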

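The contiguous_chunks=(intermediate_size, intermediate_size) argument is presumably what keeps the fused weight shardable under tensor parallelism: a naive even split of the 2 * intermediate_size output rows would hand one rank only gate rows and another only up rows, whereas splitting the gate block and the up block independently leaves every rank with a matched gate/up pair. The snippet below only illustrates that difference with row indices; it is not nanotron's sharding code.

import torch

intermediate_size, tp_size = 8, 2

# Fused weight rows: gate block (rows 0..7) followed by up block (rows 8..15).
fused_rows = torch.arange(2 * intermediate_size)

# Naive even split: rank 0 would get only gate rows, rank 1 only up rows.
naive = torch.chunk(fused_rows, tp_size)

# Chunk-aware split: shard the gate block and the up block independently,
# then give each rank its gate slice plus its up slice.
gate_block, up_block = torch.split(fused_rows, [intermediate_size, intermediate_size])
chunk_aware = [
    torch.cat([g, u])
    for g, u in zip(torch.chunk(gate_block, tp_size), torch.chunk(up_block, tp_size))
]

print(naive[0])        # tensor([0, 1, 2, 3, 4, 5, 6, 7])          -> gate rows only
print(chunk_aware[0])  # tensor([ 0,  1,  2,  3,  8,  9, 10, 11])  -> gate + up pair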
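If a checkpoint still stores separate gate_proj and up_proj weights (the pre-commit llama_sft.py layout), it could in principle be adapted to the fused layout by concatenating the two matrices along the output dimension, gate rows first. convert_to_fused below is a hypothetical helper written for illustration; it is not part of the repository.

import torch

def convert_to_fused(state_dict: dict, prefix: str = "mlp.") -> dict:
    """Hypothetical conversion: merge separate gate/up projection weights
    into a single gate_up_proj weight matching the fused layout."""
    gate = state_dict.pop(prefix + "gate_proj.weight")  # [intermediate, hidden]
    up = state_dict.pop(prefix + "up_proj.weight")      # [intermediate, hidden]
    # Gate rows first, then up rows: the same order the fused projection
    # is split in when the activation separates its two halves.
    state_dict[prefix + "gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
    return state_dict

old = {
    "mlp.gate_proj.weight": torch.randn(128, 64),
    "mlp.up_proj.weight": torch.randn(128, 64),
}
new = convert_to_fused(old)
print(new["mlp.gate_up_proj.weight"].shape)  # torch.Size([256, 64])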