From acd25b74f3f3a5da788af7b478d05b834d6d7e15 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 24 Oct 2024 05:33:32 +0000 Subject: [PATCH 01/73] First commit --- .../models/mamba2/modeling_mamba2.py | 2 +- src/transformers/models/zamba2/__init__.py | 57 + .../models/zamba2/modular_zamba2.py | 1599 +++++++++++++++++ 3 files changed, 1657 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/zamba2/__init__.py create mode 100644 src/transformers/models/zamba2/modular_zamba2.py diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 110ae09a388..23006a1419c 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,7 +44,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update = None + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py new file mode 100644 index 00000000000..af01a5f2a64 --- /dev/null +++ b/src/transformers/models/zamba2/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_zamba2": ["Zamba2Config"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_zamba2"] = [ + "Zamba2ForCausalLM", + "Zamba2ForSequenceClassification", + "Zamba2Model", + "Zamba2PreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_zamba2 import Zamba2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_zamba2 import ( + Zamba2ForCausalLM, + Zamba2ForSequenceClassification, + Zamba2Model, + Zamba2PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py new file mode 100644 index 00000000000..c0149ae6a5b --- /dev/null +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -0,0 +1,1599 @@ +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+    AttentionMaskConverter,
+)
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from ...utils.import_utils import (
+    is_causal_conv1d_available,
+    is_mamba_ssm_available,
+    is_torchdynamo_compiling,
+)
+
+
+if is_mamba_ssm_available():
+    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+else:
+    selective_state_update = None
+
+if is_causal_conv1d_available():
+    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+    causal_conv1d_update, causal_conv1d_fn = None, None
+
+is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+from ..zamba.modeling_zamba import (
+    ZambaAttentionDecoderLayer,
+    ZambaMambaDecoderLayer,
+    ZambaAttention,
+    HybridLayer,
+    ZambaForCausalLM,
+    ZambaForSequenceClassification,
+    ZambaModel,
+    ZambaPreTrainedModel,
+    ZambaRMSNorm,
+    HybridMambaAttentionDynamicCache,
+    repeat_kv,
+)
+from ...configuration_utils import PretrainedConfig
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Zamba2Config"
+
+
+class Zamba2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
+    Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Zamba2 model.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
+            `input_ids` passed when calling [`Zamba2Model`].
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
+            Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
+            model has an output word embedding layer.
+        hidden_size (`int`, *optional*, defaults to 2560):
+            Dimension of the hidden representations. 
+        ffn_intermediate_size (`int`, *optional*, defaults to `None`):
+            Dimension of the MLP representations. If `None`, it is set to 4 * `hidden_size`.
+        num_hidden_layers (`int`, *optional*, defaults to 54):
+            Number of hidden layers in the model.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf).
+        mamba_headdim (`int`, *optional*, defaults to 64):
+            Dimension of each Mamba2 head (the number of heads is set to 1 in this implementation).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+            integer value, only the last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+            logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+            sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
+            significantly.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, will default to `None`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+            Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
+            `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. 
Raises ValueError if + `True` and kernels are not available + state_size (`int`, *optional*, defaults to 16): + The dimension the mamba state space latents + conv_dimension (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + expansion_factor (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + add_bias_linear (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in various layers + gated_linear_units (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use gated MLP + use_shared_block_lora (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP + inside the shared transformer blocks + state_size (`int`, *optional*, defaults to 128): + The rank of the LoRA modules inside the MLP of the shared transformer blocks + """ + + model_type = "zamba2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + max_position_embeddings=4096, + tie_word_embeddings=True, + hidden_size=2560, + num_hidden_layers=54, + + state_size=64, + conv_dimension=4, + expansion_factor=2, + mamba_headdim=64, + mamba_ngroups=1, + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + + add_bias_linear=False, + ffn_intermediate_size=None, + gated_linear_unit=True, + ffn_hidden_act="gelu", + num_attention_heads=32, + num_key_value_heads=None, + sliding_window=None, + attention_dropout=0.0, + + num_mem_blocks=1, + use_shared_block_lora=True, + use_shared_attention_lora=False, + lora_rank=128, + use_mem_eff_path=True, + use_mem_rope=False, + rope_theta=10000, + attention_hidden_size=None, + + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + + ft_lora = False, + use_long_context=False, + **kwargs, + ): + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + if ffn_intermediate_size is None: + self.ffn_intermediate_size = 4 * hidden_size + else: + self.ffn_intermediate_size = ffn_intermediate_size + self.ffn_hidden_act = ffn_hidden_act + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.num_mem_blocks = num_mem_blocks + self.use_mem_rope = use_mem_rope + self.rope_theta = rope_theta + if attention_hidden_size is None: + self.attention_hidden_size = 2 * hidden_size + else: + self.attention_hidden_size = attention_hidden_size + self.attention_dropout = attention_dropout + self.state_size = state_size + self.conv_dimension = conv_dimension + self.expansion_factor = expansion_factor + self.add_bias_linear = add_bias_linear + self.mamba_headdim = mamba_headdim + self.mamba_ngroups = mamba_ngroups + self.gated_linear_unit = gated_linear_unit + self.use_shared_block_lora = use_shared_block_lora + self.use_shared_attention_lora = use_shared_attention_lora + self.lora_rank = lora_rank + self.use_long_context=use_long_context + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + if use_long_context: + self.max_position_embeddings = 16384 + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = 
num_key_value_heads + + self.num_attention_heads = num_attention_heads + self.kv_channels = self.hidden_size // self.num_attention_heads + self.num_query_groups = self.num_attention_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + if ffn_intermediate_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + self.use_mem_eff_path = use_mem_eff_path + + + # Below, "m" stands for mamba layer, "h" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + self.layers_block_type = ['m', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'h', 'm', 'm'] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def count_mem_blocks_in_config(config: Zamba2Config): + """ + Count number of shared blocks + """ + num_gs = 0 + for val in config.layers_block_type: + if val == 'h': + num_gs +=1 + return num_gs + + +def layer_type_list(config: Zamba2Config): + """ + Returns list of layer ids containing hybrid layers + """ + ll = [] + for i, val in enumerate(config.layers_block_type): + if val == 'h': + ll.append(i) + return ll + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. 
compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba +class Zamba2RMSNorm(ZambaRMSNorm): + pass + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(Zamba2RMSNorm) + + +class Zamba2RotaryEmbedding(nn.Module): + def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + if config.use_long_context: + a = 8 + base = base * a ** (dim / (dim-2)) + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Zamba2Attention(ZambaAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + """ + + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_mem_blocks = None): + super().__init__(config, layer_idx) + self.num_mem_blocks = num_mem_blocks + self.rope_theta = config.rope_theta + self.layer_block_map = layer_type_list(config) + + ### add to config: + # config.attention_hidden_size + # config.attention_head_dim + # config.max_position_embeddings + if config.use_shared_attention_lora: + self.linear_q_lora_A_list = nn.ParameterList([]) + self.linear_q_lora_B_list = nn.ParameterList([]) + self.linear_k_lora_A_list = nn.ParameterList([]) + self.linear_k_lora_B_list = nn.ParameterList([]) + self.linear_v_lora_A_list = nn.ParameterList([]) + self.linear_v_lora_B_list = nn.ParameterList([]) + + for i in range(self.num_mem_blocks): + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_q_lora_A_list.append(linear_q_lora_A) + self.linear_q_lora_B_list.append(linear_q_lora_B) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_k_lora_A_list.append(linear_k_lora_A) + self.linear_k_lora_B_list.append(linear_k_lora_B) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_v_lora_A_list.append(linear_v_lora_A) + self.linear_v_lora_B_list.append(linear_v_lora_B) + + if config.use_mem_rope: + self.rotary_emb = 
Zamba2RotaryEmbedding( + config, + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + 
attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward +# dropped use_sliding_windows from the arguments of self._flash_attention_forward +class Zamba2FlashAttention2(Zamba2Attention): + """ + Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=softmax_scale, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention +class Zamba2SdpaAttention(Zamba2Attention): + """ + Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + scale=softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +ZAMBA2_ATTENTION_CLASSES = { + "eager": Zamba2Attention, + "flash_attention_2": Zamba2FlashAttention2, + "sdpa": Zamba2SdpaAttention, +} + + +class Zamba2Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Zamba2Config, layer_idx: int = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_dimension + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias # add this with default True + self.activation = "silu" + self.act = nn.SiLU() + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.mamba_ngroups + self.head_dim = config.mamba_headdim + self.num_heads = self.intermediate_size // self.head_dim + self.chunk_size = config.chunk_size # add this with default 256 + + self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) + self.time_step_min = config.time_step_min # add this, with same default as zamba1 + self.time_step_max = config.time_step_max # add this, with same default as zamba1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=True, + kernel_size=config.conv_dimension, + groups=self.conv_dim, + padding=config.conv_dimension - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.add_bias_linear, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=1e-5) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # set up dimensions for reshapes later + + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # getting projected states from cache if it exists + if cache_params is not None and cache_params.seqlen_offset > 0: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] + _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) # (nheads,) + + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + out = self.out_proj(hidden_states)[:, None, ...] + # if no cache is found, calling the kernel + else: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + if self.training and cache_params is None: + out, ssm_state = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=True, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = 
cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
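+            # Single-token decode step (cache already warm): dt is broadcast to [batch, num_heads, head_dim],
+            # shifted by dt_bias and passed through softplus, then clamped below by time_step_min. A and B are
+            # discretized into dA and dB, the cached SSM state is updated in place, and the output is read out
+            # of the state with C, with the D skip connection added at the end.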
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. 
Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class Zamba2MLP(nn.Module): + def __init__(self, config: Zamba2Config, num_mem_blocks = None): + """ + Shared MLP layer. To the intermediate activations of the MLP, we add un-shared LoRA's, which + introduce some amount of diversification across the shared MLP layers. + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_mem_blocks = num_mem_blocks + self.ffn_intermediate_size = config.ffn_intermediate_size + + self.act_fn = ACT2FN[config.ffn_hidden_act] + def gated_act_fn(x): + x = torch.chunk(x, 2, dim=-1) + + return self.act_fn(x[0]) * x[1] + self.gated_act_fn = gated_act_fn + + self.linear_fc1 = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) + self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if self.config.use_shared_block_lora: + self.linear_fc1_lora_A_list = nn.ModuleList([]) + self.linear_fc1_lora_B_list = nn.ModuleList([]) + for i in range(self.num_mem_blocks): + linear_fc1_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) + linear_fc1_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + self.linear_fc1_lora_A_list.append(linear_fc1_lora_A) + self.linear_fc1_lora_B_list.append(linear_fc1_lora_B) + + def forward(self, hidden_state, layer_idx = None): + if self.config.use_shared_block_lora: + linear_fc1_lora_A = self.linear_fc1_lora_A_list[layer_idx] + linear_fc1_lora_B = self.linear_fc1_lora_B_list[layer_idx] + lora_output = linear_fc1_lora_A(hidden_state) + lora_output = linear_fc1_lora_B(lora_output) + intermediate_state = self.linear_fc1(hidden_state) + hidden_state = intermediate_state + lora_output + else: + hidden_state = self.linear_fc1(hidden_state) + + hidden_state = self.gated_act_fn(hidden_state) + output = self.down_proj(hidden_state) + return output + + +class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + num_gs = count_mem_blocks_in_config(config) + self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_mem_blocks = num_gs) + self.feed_forward = Zamba2MLP(config, num_mem_blocks = num_gs) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + 
output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Zamba2Mamba2DecoderLayer(ZambaMambaDecoderLayer): + def __init__(self, config: Zamba2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + + +class Zamba2HybridLayer(HybridLayer): + pass + + +ZAMBA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Zamba2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2PreTrainedModel(ZambaPreTrainedModel): + _supports_flash_attn_2 = True + # Leaving this commented out for now until testing + # _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2Mamba2DecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Zamba2Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + num_heads = int(self.config.expand * self.config.hidden_size) // self.config.mamba_headdim + dt = torch.exp( + torch.rand(num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + @classmethod + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` + as we do not want to disable Flash Attention 2 in Zamba2. + """ + return super(ZambaPreTrainedModel)._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + +ZAMBA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ZambaDecoderLayer`] + + Args: + config: ZambaConfig + """ + + def __init__(self, config: Zamba2Config): + super().__init__(config) + + blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)]) + ### got to here + mamba_layers = [] + linear_layers = [] + self.layers_block_type = config.layers_block_type + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "m": + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + elif config.layers_block_type[i] == "h": + linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers = iter(mamba_layers) + linear_layers = iter(linear_layers) + layers = [] + self._tied_weights_keys = [] + for layer_id, layer_type in enumerate(self.layers_block_type): + if layer_type == "h": + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_proj.weight", + "shared_transf.feed_forward.up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + else: + layers.append(next(mamba_layers)) + self.layers = nn.ModuleList(layers) + + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + +# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA +class Zamba2ForCausalLM(ZambaForCausalLM, Zamba2PreTrainedModel, GenerationMixin): + def __init__(self, config: Zamba2Config): + super().__init__(config) + self.model = Zamba2Model(config) + self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] + + # Initialize weights and apply final processing + self.post_init() + + +@add_start_docstrings( + """ + The Zamba2 Model with a sequence classification head on top (linear layer). + + [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + ZAMBA2_START_DOCSTRING, +) +class Zamba2ForSequenceClassification(ZambaForSequenceClassification, Zamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = Zamba2Model(config) + self._tied_weights_keys = self.model._tied_weights_keys + + # Initialize weights and apply final processing + self.post_init() \ No newline at end of file From 70639b842bc6bde44baad9738f310d1a74ae463a Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 28 Oct 2024 22:55:56 +0000 Subject: [PATCH 02/73] Finish model implementation --- .../models/zamba2/modular_zamba2.py | 144 +++++++++++------- 1 file changed, 86 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c0149ae6a5b..7e9dccf713a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,6 +15,7 @@ # limitations under the License. import math from typing import Any, Dict, List, Optional, Tuple, Union +from itertools import cycle import torch import torch.utils.checkpoint @@ -144,9 +145,9 @@ class Zamba2Config(PretrainedConfig): `True` and kernels are not available state_size (`int`, *optional*, defaults to 16): The dimension the mamba state space latents - conv_dimension (`int`, *optional*, defaults to 4): + mamba_d_conv (`int`, *optional*, defaults to 4): The size of the mamba convolution kernel - expansion_factor (`int`, *optional*, defaults to 2): + mamba_expand (`int`, *optional*, defaults to 2): Expanding factor (relative to hidden_size) used to determine the mamba intermediate size add_bias_linear (`bool`, *optional*, defaults to `False`): Flag indicating whether or not to use bias in various layers @@ -170,19 +171,29 @@ def __init__( hidden_size=2560, num_hidden_layers=54, - state_size=64, - conv_dimension=4, - expansion_factor=2, + mamba_d_state=64, + mamba_d_conv=4, + mamba_expand=2, mamba_headdim=64, mamba_ngroups=1, time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + + mamba_dt_rank="auto", + n_mamba_heads=1, + mamba_conv_bias=True, + mamba_proj_bias=False, + hidden_mamba_act="silu", + use_mamba_kernels=True, + use_conv_bias=True, + chunk_size=256, add_bias_linear=False, - ffn_intermediate_size=None, + intermediate_size=None, gated_linear_unit=True, - ffn_hidden_act="gelu", + hidden_act="gelu", num_attention_heads=32, num_key_value_heads=None, sliding_window=None, @@ -196,6 +207,7 @@ def __init__( use_mem_rope=False, rope_theta=10000, attention_hidden_size=None, + attention_head_dim=None, initializer_range=0.02, rms_norm_eps=1e-5, @@ -205,7 +217,6 @@ def __init__( bos_token_id=1, eos_token_id=2, - ft_lora = False, use_long_context=False, **kwargs, ): @@ -214,11 +225,11 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size - if ffn_intermediate_size is None: - self.ffn_intermediate_size = 4 * hidden_size + if intermediate_size is None: + self.intermediate_size = 4 * hidden_size else: - self.ffn_intermediate_size = ffn_intermediate_size - self.ffn_hidden_act = ffn_hidden_act + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.sliding_window = sliding_window @@ -229,13 +240,27 @@ def __init__( self.attention_hidden_size = 2 * hidden_size else: self.attention_hidden_size = 
attention_hidden_size + if attention_head_dim is None: + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads + else: + self.attention_head_dim = attention_head_dim self.attention_dropout = attention_dropout - self.state_size = state_size - self.conv_dimension = conv_dimension - self.expansion_factor = expansion_factor + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear self.mamba_headdim = mamba_headdim self.mamba_ngroups = mamba_ngroups + self.n_mamba_heads = n_mamba_heads + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.hidden_mamba_act = hidden_mamba_act + self.use_mamba_kernels = use_mamba_kernels + self.use_conv_bias = use_conv_bias + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.gated_linear_unit = gated_linear_unit self.use_shared_block_lora = use_shared_block_lora self.use_shared_attention_lora = use_shared_attention_lora @@ -258,7 +283,7 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps - if ffn_intermediate_size is None: + if intermediate_size is None: self.ffn_hidden_size = 4 * self.hidden_size self.use_cache = use_cache @@ -266,8 +291,8 @@ def __init__( self.use_mem_eff_path = use_mem_eff_path - # Below, "m" stands for mamba layer, "h" stands for hybrid layer (composed by a shared transformer followed by mamba layer) - self.layers_block_type = ['m', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'h', 'm', 'm'] + # Below, "mmamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + self.layers_block_type = ['mamba'] + (['mamba'] * 5 + ['hybrid']) * 7 + ['mamba'] * 4 + ['hybrid'] + ['mamba'] * 3 + ['hybrid'] + ['mamba'] * 2 super().__init__( pad_token_id=pad_token_id, @@ -284,7 +309,7 @@ def count_mem_blocks_in_config(config: Zamba2Config): """ num_gs = 0 for val in config.layers_block_type: - if val == 'h': + if val == 'hybrid': num_gs +=1 return num_gs @@ -295,7 +320,7 @@ def layer_type_list(config: Zamba2Config): """ ll = [] for i, val in enumerate(config.layers_block_type): - if val == 'h': + if val == 'hybrid': ll.append(i) return ll @@ -501,6 +526,8 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_me base=self.rope_theta, ) + self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + def forward( self, hidden_states: torch.Tensor, @@ -516,20 +543,21 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + lora_layer_idx = self.layer_dic[layer_idx] + linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] q_lora_output = linear_q_lora_A(hidden_states) q_lora_output = linear_q_lora_B(q_lora_output) query_states = self.q_proj(hidden_states) query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + linear_k_lora_A 
= self.linear_k_lora_A_list[lora_layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[lora_layer_idx] k_lora_output = linear_k_lora_A(hidden_states) k_lora_output = linear_k_lora_B(k_lora_output) key_states = self.k_proj(hidden_states) key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + linear_v_lora_A = self.linear_v_lora_A_list[lora_layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[lora_layer_idx] v_lora_output = linear_v_lora_A(hidden_states) v_lora_output = linear_v_lora_B(v_lora_output) value_states = self.v_proj(hidden_states) @@ -834,22 +862,18 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_dimension - self.intermediate_size = int(config.expand * self.hidden_size) - self.time_step_rank = int(config.time_step_rank) + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) self.layer_idx = layer_idx self.use_conv_bias = config.use_conv_bias # add this with default True self.activation = "silu" self.act = nn.SiLU() - self.layer_norm_epsilon = config.layer_norm_epsilon - self.rms_norm = config.rms_norm - self.n_groups = config.mamba_ngroups self.head_dim = config.mamba_headdim self.num_heads = self.intermediate_size // self.head_dim - self.chunk_size = config.chunk_size # add this with default 256 + self.chunk_size = config.chunk_size self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) self.time_step_min = config.time_step_min # add this, with same default as zamba1 @@ -860,9 +884,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): in_channels=self.conv_dim, out_channels=self.conv_dim, bias=True, - kernel_size=config.conv_dimension, + kernel_size=config.mamba_d_conv, groups=self.conv_dim, - padding=config.conv_dimension - 1, + padding=config.mamba_d_conv - 1, ) # projection of the input hidden states @@ -1252,37 +1276,41 @@ def __init__(self, config: Zamba2Config, num_mem_blocks = None): self.config = config self.hidden_size = config.hidden_size self.num_mem_blocks = num_mem_blocks - self.ffn_intermediate_size = config.ffn_intermediate_size + self.ffn_intermediate_size = config.intermediate_size - self.act_fn = ACT2FN[config.ffn_hidden_act] + self.act_fn = ACT2FN[config.hidden_act] def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) return self.act_fn(x[0]) * x[1] self.gated_act_fn = gated_act_fn - self.linear_fc1 = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_block_lora: - self.linear_fc1_lora_A_list = nn.ModuleList([]) - self.linear_fc1_lora_B_list = nn.ModuleList([]) + self.gate_up_proj_lora_A_list = nn.ModuleList([]) + self.gate_up_proj_lora_B_list = nn.ModuleList([]) for i in range(self.num_mem_blocks): - linear_fc1_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) - linear_fc1_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) - self.linear_fc1_lora_A_list.append(linear_fc1_lora_A) - 
self.linear_fc1_lora_B_list.append(linear_fc1_lora_B) + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) + self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) + + layer_block_map = layer_type_list(config) + self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx = None): if self.config.use_shared_block_lora: - linear_fc1_lora_A = self.linear_fc1_lora_A_list[layer_idx] - linear_fc1_lora_B = self.linear_fc1_lora_B_list[layer_idx] - lora_output = linear_fc1_lora_A(hidden_state) - lora_output = linear_fc1_lora_B(lora_output) - intermediate_state = self.linear_fc1(hidden_state) + layer_idx = self.layer_dic[layer_idx] + gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] + gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] + lora_output = gate_up_proj_lora_A(hidden_state) + lora_output = gate_up_proj_lora_B(lora_output) + intermediate_state = self.gate_up_proj(hidden_state) hidden_state = intermediate_state + lora_output else: - hidden_state = self.linear_fc1(hidden_state) + hidden_state = self.gate_up_proj(hidden_state) hidden_state = self.gated_act_fn(hidden_state) output = self.down_proj(hidden_state) @@ -1406,7 +1434,7 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - num_heads = int(self.config.expand * self.config.hidden_size) // self.config.mamba_headdim + num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) @@ -1522,23 +1550,23 @@ class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): def __init__(self, config: Zamba2Config): super().__init__(config) - blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)]) - ### got to here + blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] mamba_layers = [] linear_layers = [] self.layers_block_type = config.layers_block_type for i in range(config.num_hidden_layers): - if config.layers_block_type[i] == "m": + if config.layers_block_type[i] == "mamba": mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) - elif config.layers_block_type[i] == "h": + elif config.layers_block_type[i] == "hybrid": linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) + blocks = cycle(blocks) layers = [] self._tied_weights_keys = [] for layer_id, layer_type in enumerate(self.layers_block_type): - if layer_type == "h": + if layer_type == "hybrid": prefix_name = f"layers.{layer_id}." 
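                # Editor's note (illustrative, not part of the original patch): `cycle(blocks)`
                # below hands the `num_mem_blocks` shared attention blocks out round-robin to the
                # hybrid layers, so the same block object is wrapped by several hybrid layers and
                # its parameters appear once per hybrid layer in the state dict; registering them
                # in `_tied_weights_keys` marks those duplicate entries as tied for save/load.
                # A minimal sketch of the round-robin assignment (hypothetical names):
                #
                #     from itertools import cycle
                #     shared = cycle(["block_0", "block_1"])          # num_mem_blocks = 2
                #     [next(shared) for _ in range(5)]
                #     # -> ['block_0', 'block_1', 'block_0', 'block_1', 'block_0']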
tied_keys = [ "shared_transf.self_attn.q_proj.weight", @@ -1552,7 +1580,7 @@ def __init__(self, config: Zamba2Config): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + layers.append(HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) From d111b98886f8b626d77c92904ce5b02a2002c9f2 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 24 Oct 2024 05:33:32 +0000 Subject: [PATCH 03/73] First commit --- .../models/mamba2/modeling_mamba2.py | 2 +- src/transformers/models/zamba2/__init__.py | 57 + .../models/zamba2/modular_zamba2.py | 1599 +++++++++++++++++ 3 files changed, 1657 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/zamba2/__init__.py create mode 100644 src/transformers/models/zamba2/modular_zamba2.py diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index c312b9b9435..8661495dbf6 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,7 +44,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update = None + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py new file mode 100644 index 00000000000..af01a5f2a64 --- /dev/null +++ b/src/transformers/models/zamba2/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
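# Editor's sketch (illustrative, not part of this patch): the `__init__.py` below registers its
# public classes in `_import_structure` and then replaces the package module with a `_LazyModule`,
# so the torch-dependent modeling code is only imported when one of those classes is first
# accessed. A self-contained toy imitation of that mechanism (`_ToyLazyModule` is hypothetical,
# not the real `transformers.utils._LazyModule`):
import importlib
import types


class _ToyLazyModule(types.ModuleType):
    """Toy stand-in for the lazy-import machinery used below (illustrative only)."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Map every public symbol back to the submodule that defines it.
        self._symbol_to_module = {
            symbol: submodule for submodule, symbols in import_structure.items() for symbol in symbols
        }
        self.__all__ = list(self._symbol_to_module)

    def __getattr__(self, name: str):
        if name not in self._symbol_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # The heavy submodule (e.g. the modeling file that needs torch) is imported here,
        # on first attribute access, rather than at package import time.
        submodule = importlib.import_module(f"{self.__name__}.{self._symbol_to_module[name]}")
        return getattr(submodule, name)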
+from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = { + "configuration_zamba2": ["Zamba2Config"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_zamba2"] = [ + "Zamba2ForCausalLM", + "Zamba2ForSequenceClassification", + "Zamba2Model", + "Zamba2PreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_zamba2 import Zamba2Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_zamba2 import ( + Zamba2ForCausalLM, + Zamba2ForSequenceClassification, + Zamba2Model, + Zamba2PreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py new file mode 100644 index 00000000000..c0149ae6a5b --- /dev/null +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -0,0 +1,1599 @@ +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
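# Editor's sketch (illustrative, not part of this patch): the `segment_sum` helper defined
# further down in this file turns a per-step log-decay vector A into a lower-triangular matrix
# whose (i, j) entry is sum(A[j+1 : i+1]) (0 on the diagonal, -inf above it); the chunked SSD
# scan then uses L = exp(segment_sum(A)) as a causal decay mask. A standalone numeric check
# mirroring that helper on a single 3-step chunk:
import torch


def segment_sum_demo(x: torch.Tensor) -> torch.Tensor:
    # Same expand / mask / cumsum trick as the `segment_sum` helper below.
    chunk_size = x.size(-1)
    x = x[..., None].expand(*x.size(), chunk_size)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    out = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool), diagonal=0)
    return out.masked_fill(~mask, -torch.inf)


print(segment_sum_demo(torch.tensor([1.0, 2.0, 3.0])))
# tensor([[0., -inf, -inf],
#         [2., 0., -inf],
#         [5., 3., 0.]])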
+import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_ssm_available, + is_torchdynamo_compiling, +) + + +if is_mamba_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + selective_state_update = None, None, None + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) +from ..zamba.modeling_zamba import ( + ZambaAttentionDecoderLayer, + ZambaMambaDecoderLayer, + ZambaAttention, + HybridLayer, + ZambaForCausalLM, + ZambaForSequenceClassification, + ZambaModel, + ZambaPreTrainedModel, + ZambaRMSNorm, + HybridMambaAttentionDynamicCache, + repeat_kv, +) +from ...configuration_utils import PretrainedConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Zamba2Config" + + +class Zamba2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a + Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Zamba2 model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Zamba2Model`] + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + ffn_hidden_size (`int`, *optional*, defaults to 4 * hidden_size): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 54): + Number of hidden layers in the model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If + `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). + mamba_headdim (``, *optional*, defaults to 64): + dimension of each Mamba2 heads (number of heads is set to 1 in this implementation). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. 
Raises ValueError if + `True` and kernels are not available + state_size (`int`, *optional*, defaults to 16): + The dimension the mamba state space latents + conv_dimension (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + expansion_factor (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + add_bias_linear (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in various layers + gated_linear_units (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use gated MLP + use_shared_block_lora (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP + inside the shared transformer blocks + state_size (`int`, *optional*, defaults to 128): + The rank of the LoRA modules inside the MLP of the shared transformer blocks + """ + + model_type = "zamba2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + max_position_embeddings=4096, + tie_word_embeddings=True, + hidden_size=2560, + num_hidden_layers=54, + + state_size=64, + conv_dimension=4, + expansion_factor=2, + mamba_headdim=64, + mamba_ngroups=1, + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + + add_bias_linear=False, + ffn_intermediate_size=None, + gated_linear_unit=True, + ffn_hidden_act="gelu", + num_attention_heads=32, + num_key_value_heads=None, + sliding_window=None, + attention_dropout=0.0, + + num_mem_blocks=1, + use_shared_block_lora=True, + use_shared_attention_lora=False, + lora_rank=128, + use_mem_eff_path=True, + use_mem_rope=False, + rope_theta=10000, + attention_hidden_size=None, + + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + + ft_lora = False, + use_long_context=False, + **kwargs, + ): + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + if ffn_intermediate_size is None: + self.ffn_intermediate_size = 4 * hidden_size + else: + self.ffn_intermediate_size = ffn_intermediate_size + self.ffn_hidden_act = ffn_hidden_act + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.num_mem_blocks = num_mem_blocks + self.use_mem_rope = use_mem_rope + self.rope_theta = rope_theta + if attention_hidden_size is None: + self.attention_hidden_size = 2 * hidden_size + else: + self.attention_hidden_size = attention_hidden_size + self.attention_dropout = attention_dropout + self.state_size = state_size + self.conv_dimension = conv_dimension + self.expansion_factor = expansion_factor + self.add_bias_linear = add_bias_linear + self.mamba_headdim = mamba_headdim + self.mamba_ngroups = mamba_ngroups + self.gated_linear_unit = gated_linear_unit + self.use_shared_block_lora = use_shared_block_lora + self.use_shared_attention_lora = use_shared_attention_lora + self.lora_rank = lora_rank + self.use_long_context=use_long_context + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + if use_long_context: + self.max_position_embeddings = 16384 + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = 
num_key_value_heads + + self.num_attention_heads = num_attention_heads + self.kv_channels = self.hidden_size // self.num_attention_heads + self.num_query_groups = self.num_attention_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + if ffn_intermediate_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + self.use_mem_eff_path = use_mem_eff_path + + + # Below, "m" stands for mamba layer, "h" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + self.layers_block_type = ['m', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'h', 'm', 'm'] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +def count_mem_blocks_in_config(config: Zamba2Config): + """ + Count number of shared blocks + """ + num_gs = 0 + for val in config.layers_block_type: + if val == 'h': + num_gs +=1 + return num_gs + + +def layer_type_list(config: Zamba2Config): + """ + Returns list of layer ids containing hybrid layers + """ + ll = [] + for i, val in enumerate(config.layers_block_type): + if val == 'h': + ll.append(i) + return ll + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. 
compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba +class Zamba2RMSNorm(ZambaRMSNorm): + pass + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(Zamba2RMSNorm) + + +class Zamba2RotaryEmbedding(nn.Module): + def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + if config.use_long_context: + a = 8 + base = base * a ** (dim / (dim-2)) + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Zamba2Attention(ZambaAttention): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + """ + + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_mem_blocks = None): + super().__init__(config, layer_idx) + self.num_mem_blocks = num_mem_blocks + self.rope_theta = config.rope_theta + self.layer_block_map = layer_type_list(config) + + ### add to config: + # config.attention_hidden_size + # config.attention_head_dim + # config.max_position_embeddings + if config.use_shared_attention_lora: + self.linear_q_lora_A_list = nn.ParameterList([]) + self.linear_q_lora_B_list = nn.ParameterList([]) + self.linear_k_lora_A_list = nn.ParameterList([]) + self.linear_k_lora_B_list = nn.ParameterList([]) + self.linear_v_lora_A_list = nn.ParameterList([]) + self.linear_v_lora_B_list = nn.ParameterList([]) + + for i in range(self.num_mem_blocks): + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_q_lora_A_list.append(linear_q_lora_A) + self.linear_q_lora_B_list.append(linear_q_lora_B) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_k_lora_A_list.append(linear_k_lora_A) + self.linear_k_lora_B_list.append(linear_k_lora_B) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + self.linear_v_lora_A_list.append(linear_v_lora_A) + self.linear_v_lora_B_list.append(linear_v_lora_B) + + if config.use_mem_rope: + self.rotary_emb = 
Zamba2RotaryEmbedding( + config, + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + 
attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward +# dropped use_sliding_windows from the arguments of self._flash_attention_forward +class Zamba2FlashAttention2(Zamba2Attention): + """ + Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = 
apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=softmax_scale, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention +class Zamba2SdpaAttention(Zamba2Attention): + """ + Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. 
This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
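+            # `scale` below is 1/sqrt(head_dim/2) rather than 1/sqrt(head_dim): the attention input is the concatenation of the hidden states with the original embeddings, so head_dim is twice its usual size and this restores the standard softmax scaling.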
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + scale=softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +ZAMBA2_ATTENTION_CLASSES = { + "eager": Zamba2Attention, + "flash_attention_2": Zamba2FlashAttention2, + "sdpa": Zamba2SdpaAttention, +} + + +class Zamba2Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Zamba2Config, layer_idx: int = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_dimension + self.intermediate_size = int(config.expand * self.hidden_size) + self.time_step_rank = int(config.time_step_rank) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias # add this with default True + self.activation = "silu" + self.act = nn.SiLU() + + self.layer_norm_epsilon = config.layer_norm_epsilon + self.rms_norm = config.rms_norm + + self.n_groups = config.mamba_ngroups + self.head_dim = config.mamba_headdim + self.num_heads = self.intermediate_size // self.head_dim + self.chunk_size = config.chunk_size # add this with default 256 + + self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) + self.time_step_min = config.time_step_min # add this, with same default as zamba1 + self.time_step_max = config.time_step_max # add this, with same default as zamba1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=True, + kernel_size=config.conv_dimension, + groups=self.conv_dim, + padding=config.conv_dimension - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.add_bias_linear, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=1e-5) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # set up dimensions for reshapes later + + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # getting projected states from cache if it exists + if cache_params is not None and cache_params.seqlen_offset > 0: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] + _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) # (nheads,) + + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + out = self.out_proj(hidden_states)[:, None, ...] + # if no cache is found, calling the kernel + else: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + if self.training and cache_params is None: + out, ssm_state = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=True, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = 
cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
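+            # dt is now [batch, 1, num_heads]; the lines below broadcast it to [batch, num_heads, head_dim], apply softplus with the learned bias, and update the cached SSM state in place via the discretization dA = exp(dt * A), dB = dt * B.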
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. 
Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class Zamba2MLP(nn.Module): + def __init__(self, config: Zamba2Config, num_mem_blocks = None): + """ + Shared MLP layer. To the intermediate activations of the MLP, we add un-shared LoRA's, which + introduce some amount of diversification across the shared MLP layers. + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_mem_blocks = num_mem_blocks + self.ffn_intermediate_size = config.ffn_intermediate_size + + self.act_fn = ACT2FN[config.ffn_hidden_act] + def gated_act_fn(x): + x = torch.chunk(x, 2, dim=-1) + + return self.act_fn(x[0]) * x[1] + self.gated_act_fn = gated_act_fn + + self.linear_fc1 = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) + self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if self.config.use_shared_block_lora: + self.linear_fc1_lora_A_list = nn.ModuleList([]) + self.linear_fc1_lora_B_list = nn.ModuleList([]) + for i in range(self.num_mem_blocks): + linear_fc1_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) + linear_fc1_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + self.linear_fc1_lora_A_list.append(linear_fc1_lora_A) + self.linear_fc1_lora_B_list.append(linear_fc1_lora_B) + + def forward(self, hidden_state, layer_idx = None): + if self.config.use_shared_block_lora: + linear_fc1_lora_A = self.linear_fc1_lora_A_list[layer_idx] + linear_fc1_lora_B = self.linear_fc1_lora_B_list[layer_idx] + lora_output = linear_fc1_lora_A(hidden_state) + lora_output = linear_fc1_lora_B(lora_output) + intermediate_state = self.linear_fc1(hidden_state) + hidden_state = intermediate_state + lora_output + else: + hidden_state = self.linear_fc1(hidden_state) + + hidden_state = self.gated_act_fn(hidden_state) + output = self.down_proj(hidden_state) + return output + + +class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + num_gs = count_mem_blocks_in_config(config) + self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_mem_blocks = num_gs) + self.feed_forward = Zamba2MLP(config, num_mem_blocks = num_gs) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + 
output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Zamba2Mamba2DecoderLayer(ZambaMambaDecoderLayer): + def __init__(self, config: Zamba2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + + +class Zamba2HybridLayer(HybridLayer): + pass + + +ZAMBA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Zamba2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
+""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2PreTrainedModel(ZambaPreTrainedModel): + _supports_flash_attn_2 = True + # Leaving this commented out for now until testing + # _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2Mamba2DecoderLayer"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Zamba2Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + num_heads = int(self.config.expand * self.config.hidden_size) // self.config.mamba_headdim + dt = torch.exp( + torch.rand(num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + @classmethod + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` + as we do not want to disable Flash Attention 2 in Zamba2. + """ + return super(ZambaPreTrainedModel)._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + +ZAMBA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`ZambaDecoderLayer`] + + Args: + config: ZambaConfig + """ + + def __init__(self, config: Zamba2Config): + super().__init__(config) + + blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)]) + ### got to here + mamba_layers = [] + linear_layers = [] + self.layers_block_type = config.layers_block_type + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "m": + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + elif config.layers_block_type[i] == "h": + linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers = iter(mamba_layers) + linear_layers = iter(linear_layers) + layers = [] + self._tied_weights_keys = [] + for layer_id, layer_type in enumerate(self.layers_block_type): + if layer_type == "h": + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_proj.weight", + "shared_transf.feed_forward.up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + else: + layers.append(next(mamba_layers)) + self.layers = nn.ModuleList(layers) + + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + +# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA +class Zamba2ForCausalLM(ZambaForCausalLM, Zamba2PreTrainedModel, GenerationMixin): + def __init__(self, config: Zamba2Config): + super().__init__(config) + self.model = Zamba2Model(config) + self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] + + # Initialize weights and apply final processing + self.post_init() + + +@add_start_docstrings( + """ + The Zamba2 Model with a sequence classification head on top (linear layer). + + [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
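+                # Parameter names of the shared transformer block are recorded as tied weights so the shared parameters are not treated as independent copies when saving or loading checkpoints.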
+ """, + ZAMBA2_START_DOCSTRING, +) +class Zamba2ForSequenceClassification(ZambaForSequenceClassification, Zamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = Zamba2Model(config) + self._tied_weights_keys = self.model._tied_weights_keys + + # Initialize weights and apply final processing + self.post_init() \ No newline at end of file From 8f36dba7d9b147f7ef78d2592eee4bbd8d2df437 Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 28 Oct 2024 22:55:56 +0000 Subject: [PATCH 04/73] Finish model implementation --- .../models/zamba2/modular_zamba2.py | 144 +++++++++++------- 1 file changed, 86 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c0149ae6a5b..7e9dccf713a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,6 +15,7 @@ # limitations under the License. import math from typing import Any, Dict, List, Optional, Tuple, Union +from itertools import cycle import torch import torch.utils.checkpoint @@ -144,9 +145,9 @@ class Zamba2Config(PretrainedConfig): `True` and kernels are not available state_size (`int`, *optional*, defaults to 16): The dimension the mamba state space latents - conv_dimension (`int`, *optional*, defaults to 4): + mamba_d_conv (`int`, *optional*, defaults to 4): The size of the mamba convolution kernel - expansion_factor (`int`, *optional*, defaults to 2): + mamba_expand (`int`, *optional*, defaults to 2): Expanding factor (relative to hidden_size) used to determine the mamba intermediate size add_bias_linear (`bool`, *optional*, defaults to `False`): Flag indicating whether or not to use bias in various layers @@ -170,19 +171,29 @@ def __init__( hidden_size=2560, num_hidden_layers=54, - state_size=64, - conv_dimension=4, - expansion_factor=2, + mamba_d_state=64, + mamba_d_conv=4, + mamba_expand=2, mamba_headdim=64, mamba_ngroups=1, time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + + mamba_dt_rank="auto", + n_mamba_heads=1, + mamba_conv_bias=True, + mamba_proj_bias=False, + hidden_mamba_act="silu", + use_mamba_kernels=True, + use_conv_bias=True, + chunk_size=256, add_bias_linear=False, - ffn_intermediate_size=None, + intermediate_size=None, gated_linear_unit=True, - ffn_hidden_act="gelu", + hidden_act="gelu", num_attention_heads=32, num_key_value_heads=None, sliding_window=None, @@ -196,6 +207,7 @@ def __init__( use_mem_rope=False, rope_theta=10000, attention_hidden_size=None, + attention_head_dim=None, initializer_range=0.02, rms_norm_eps=1e-5, @@ -205,7 +217,6 @@ def __init__( bos_token_id=1, eos_token_id=2, - ft_lora = False, use_long_context=False, **kwargs, ): @@ -214,11 +225,11 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size - if ffn_intermediate_size is None: - self.ffn_intermediate_size = 4 * hidden_size + if intermediate_size is None: + self.intermediate_size = 4 * hidden_size else: - self.ffn_intermediate_size = ffn_intermediate_size - self.ffn_hidden_act = ffn_hidden_act + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.sliding_window = sliding_window @@ -229,13 +240,27 @@ def __init__( self.attention_hidden_size = 2 * hidden_size else: self.attention_hidden_size = 
attention_hidden_size + if attention_head_dim is None: + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads + else: + self.attention_head_dim = attention_head_dim self.attention_dropout = attention_dropout - self.state_size = state_size - self.conv_dimension = conv_dimension - self.expansion_factor = expansion_factor + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear self.mamba_headdim = mamba_headdim self.mamba_ngroups = mamba_ngroups + self.n_mamba_heads = n_mamba_heads + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.hidden_mamba_act = hidden_mamba_act + self.use_mamba_kernels = use_mamba_kernels + self.use_conv_bias = use_conv_bias + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + self.gated_linear_unit = gated_linear_unit self.use_shared_block_lora = use_shared_block_lora self.use_shared_attention_lora = use_shared_attention_lora @@ -258,7 +283,7 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps - if ffn_intermediate_size is None: + if intermediate_size is None: self.ffn_hidden_size = 4 * self.hidden_size self.use_cache = use_cache @@ -266,8 +291,8 @@ def __init__( self.use_mem_eff_path = use_mem_eff_path - # Below, "m" stands for mamba layer, "h" stands for hybrid layer (composed by a shared transformer followed by mamba layer) - self.layers_block_type = ['m', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'm', 'h', 'm', 'm', 'm', 'h', 'm', 'm'] + # Below, "mmamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + self.layers_block_type = ['mamba'] + (['mamba'] * 5 + ['hybrid']) * 7 + ['mamba'] * 4 + ['hybrid'] + ['mamba'] * 3 + ['hybrid'] + ['mamba'] * 2 super().__init__( pad_token_id=pad_token_id, @@ -284,7 +309,7 @@ def count_mem_blocks_in_config(config: Zamba2Config): """ num_gs = 0 for val in config.layers_block_type: - if val == 'h': + if val == 'hybrid': num_gs +=1 return num_gs @@ -295,7 +320,7 @@ def layer_type_list(config: Zamba2Config): """ ll = [] for i, val in enumerate(config.layers_block_type): - if val == 'h': + if val == 'hybrid': ll.append(i) return ll @@ -501,6 +526,8 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_me base=self.rope_theta, ) + self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + def forward( self, hidden_states: torch.Tensor, @@ -516,20 +543,21 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + lora_layer_idx = self.layer_dic[layer_idx] + linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] q_lora_output = linear_q_lora_A(hidden_states) q_lora_output = linear_q_lora_B(q_lora_output) query_states = self.q_proj(hidden_states) query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + linear_k_lora_A 
= self.linear_k_lora_A_list[lora_layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[lora_layer_idx] k_lora_output = linear_k_lora_A(hidden_states) k_lora_output = linear_k_lora_B(k_lora_output) key_states = self.k_proj(hidden_states) key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + linear_v_lora_A = self.linear_v_lora_A_list[lora_layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[lora_layer_idx] v_lora_output = linear_v_lora_A(hidden_states) v_lora_output = linear_v_lora_B(v_lora_output) value_states = self.v_proj(hidden_states) @@ -834,22 +862,18 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_dimension - self.intermediate_size = int(config.expand * self.hidden_size) - self.time_step_rank = int(config.time_step_rank) + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) self.layer_idx = layer_idx self.use_conv_bias = config.use_conv_bias # add this with default True self.activation = "silu" self.act = nn.SiLU() - self.layer_norm_epsilon = config.layer_norm_epsilon - self.rms_norm = config.rms_norm - self.n_groups = config.mamba_ngroups self.head_dim = config.mamba_headdim self.num_heads = self.intermediate_size // self.head_dim - self.chunk_size = config.chunk_size # add this with default 256 + self.chunk_size = config.chunk_size self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) self.time_step_min = config.time_step_min # add this, with same default as zamba1 @@ -860,9 +884,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): in_channels=self.conv_dim, out_channels=self.conv_dim, bias=True, - kernel_size=config.conv_dimension, + kernel_size=config.mamba_d_conv, groups=self.conv_dim, - padding=config.conv_dimension - 1, + padding=config.mamba_d_conv - 1, ) # projection of the input hidden states @@ -1252,37 +1276,41 @@ def __init__(self, config: Zamba2Config, num_mem_blocks = None): self.config = config self.hidden_size = config.hidden_size self.num_mem_blocks = num_mem_blocks - self.ffn_intermediate_size = config.ffn_intermediate_size + self.ffn_intermediate_size = config.intermediate_size - self.act_fn = ACT2FN[config.ffn_hidden_act] + self.act_fn = ACT2FN[config.hidden_act] def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) return self.act_fn(x[0]) * x[1] self.gated_act_fn = gated_act_fn - self.linear_fc1 = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_block_lora: - self.linear_fc1_lora_A_list = nn.ModuleList([]) - self.linear_fc1_lora_B_list = nn.ModuleList([]) + self.gate_up_proj_lora_A_list = nn.ModuleList([]) + self.gate_up_proj_lora_B_list = nn.ModuleList([]) for i in range(self.num_mem_blocks): - linear_fc1_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) - linear_fc1_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) - self.linear_fc1_lora_A_list.append(linear_fc1_lora_A) - 
self.linear_fc1_lora_B_list.append(linear_fc1_lora_B) + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) + self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) + + layer_block_map = layer_type_list(config) + self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx = None): if self.config.use_shared_block_lora: - linear_fc1_lora_A = self.linear_fc1_lora_A_list[layer_idx] - linear_fc1_lora_B = self.linear_fc1_lora_B_list[layer_idx] - lora_output = linear_fc1_lora_A(hidden_state) - lora_output = linear_fc1_lora_B(lora_output) - intermediate_state = self.linear_fc1(hidden_state) + layer_idx = self.layer_dic[layer_idx] + gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] + gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] + lora_output = gate_up_proj_lora_A(hidden_state) + lora_output = gate_up_proj_lora_B(lora_output) + intermediate_state = self.gate_up_proj(hidden_state) hidden_state = intermediate_state + lora_output else: - hidden_state = self.linear_fc1(hidden_state) + hidden_state = self.gate_up_proj(hidden_state) hidden_state = self.gated_act_fn(hidden_state) output = self.down_proj(hidden_state) @@ -1406,7 +1434,7 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - num_heads = int(self.config.expand * self.config.hidden_size) // self.config.mamba_headdim + num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) @@ -1522,23 +1550,23 @@ class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): def __init__(self, config: Zamba2Config): super().__init__(config) - blocks = torch.nn.ModuleList([Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)]) - ### got to here + blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] mamba_layers = [] linear_layers = [] self.layers_block_type = config.layers_block_type for i in range(config.num_hidden_layers): - if config.layers_block_type[i] == "m": + if config.layers_block_type[i] == "mamba": mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) - elif config.layers_block_type[i] == "h": + elif config.layers_block_type[i] == "hybrid": linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) + blocks = cycle(blocks) layers = [] self._tied_weights_keys = [] for layer_id, layer_type in enumerate(self.layers_block_type): - if layer_type == "h": + if layer_type == "hybrid": prefix_name = f"layers.{layer_id}." 
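+                # Each hybrid layer pulls the next shared attention block from `cycle(blocks)`, so the same weights are reused every `num_mem_blocks`-th hybrid layer and their parameter names are registered as tied below.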
tied_keys = [ "shared_transf.self_attn.q_proj.weight", @@ -1552,7 +1580,7 @@ def __init__(self, config: Zamba2Config): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + layers.append(HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) From 700fbf0378a6b79ba67ae7111e9c3cca9e3c2d14 Mon Sep 17 00:00:00 2001 From: pglorio Date: Wed, 30 Oct 2024 17:47:27 +0000 Subject: [PATCH 05/73] Register zamba2 --- docs/source/en/index.md | 1 + docs/source/en/model_doc/zamba2.md | 93 +++++++++++++++++++ src/transformers/__init__.py | 16 ++++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/auto/tokenization_auto.py | 7 ++ src/transformers/utils/dummy_pt_objects.py | 28 ++++++ 8 files changed, 151 insertions(+) create mode 100644 docs/source/en/model_doc/zamba2.md diff --git a/docs/source/en/index.md b/docs/source/en/index.md index aaff45ab65d..22fe641ed0c 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -364,6 +364,7 @@ Flax), PyTorch, and/or TensorFlow. | [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ | | [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ | | [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ | +| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ | | [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/zamba2.md b/docs/source/en/model_doc/zamba2.md new file mode 100644 index 00000000000..75333555d45 --- /dev/null +++ b/docs/source/en/model_doc/zamba2.md @@ -0,0 +1,93 @@ + +# Zamba2 + +Zamba2 is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights. + +This model was contributed by [pglo](https://huggingface.co/pglo). + + +## Model details + +Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space models (Specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T and 3T tokens, respectively. + + + +## Quick start + + +### Presequities + +Zamba2 requires you use `transformers` version 4.46.0 or higher: +```bash +pip install transformers>=4.46.0 +``` + +## Inference + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B") +model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16) + +input_text = "What factors contributed to the fall of the Roman Empire?" 
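+# Tokenize the prompt and move the resulting tensors to the same CUDA device as the model.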
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda") + +outputs = model.generate(**input_ids, max_new_tokens=100) +print(tokenizer.decode(outputs[0])) +``` + + +## Model card + +The model cards can be found at: +* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B) +* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) +* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B) + + +## Issues +For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions) + + +## License + +The model weights are open-sourced via an Apache 2.0 license. + + +## Zamba2Config + +[[autodoc]] Zamba2Config + + +## Zamba2Model + +[[autodoc]] Zamba2Model + - forward + + +## Zamba2ForCausalLM + +[[autodoc]] Zamba2ForCausalLM + - forward + + +## Zamba2ForSequenceClassification + +[[autodoc]] transformers.Zamba2ForSequenceClassification + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cc8b0739502..cd069bbf6de 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -855,6 +855,7 @@ "models.yolos": ["YolosConfig"], "models.yoso": ["YosoConfig"], "models.zamba": ["ZambaConfig"], + "models.zamba2": ["Zamba2Config"], "models.zoedepth": ["ZoeDepthConfig"], "onnx": [], "pipelines": [ @@ -3803,6 +3804,14 @@ "ZambaPreTrainedModel", ] ) + _import_structure["models.zamba2"].extend( + [ + "Zamba2ForCausalLM", + "Zamba2ForSequenceClassification", + "Zamba2Model", + "Zamba2PreTrainedModel", + ] + ) _import_structure["models.zoedepth"].extend( [ "ZoeDepthForDepthEstimation", @@ -5780,6 +5789,7 @@ from .models.yolos import YolosConfig from .models.yoso import YosoConfig from .models.zamba import ZambaConfig + from .models.zamba2 import Zamba2Config from .models.zoedepth import ZoeDepthConfig # Pipelines @@ -8207,6 +8217,12 @@ ZambaModel, ZambaPreTrainedModel, ) + from .models.zamba2 import ( + Zamba2ForCausalLM, + Zamba2ForSequenceClassification, + Zamba2Model, + Zamba2PreTrainedModel, + ) from .models.zoedepth import ( ZoeDepthForDepthEstimation, ZoeDepthPreTrainedModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 9155f629e63..4a73a410e71 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -285,5 +285,6 @@ yolos, yoso, zamba, + zamba2, zoedepth, ) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 48625ea3f34..970b70f6877 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -314,6 +314,7 @@ ("yolos", "YolosConfig"), ("yoso", "YosoConfig"), ("zamba", "ZambaConfig"), + ("zamba2", "Zamba2Config"), ("zoedepth", "ZoeDepthConfig"), ] ) @@ -637,6 +638,7 @@ ("yolos", "YOLOS"), ("yoso", "YOSO"), ("zamba", "Zamba"), + ("zamba2", "Zamba2"), ("zoedepth", "ZoeDepth"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 67c539fca66..ff8452e5bd8 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -287,6 +287,7 @@ ("yolos", "YolosModel"), ("yoso", "YosoModel"), ("zamba", "ZambaModel"), + ("zamba2", "Zamba2Model"), ] ) @@ -552,6 +553,7 @@ ("xlnet", "XLNetLMHeadModel"), ("xmod", "XmodForCausalLM"), ("zamba", "ZambaForCausalLM"), + ("zamba2", "Zamba2ForCausalLM"), ] ) @@ -1008,6 +1010,7 @@ ("xmod", 
"XmodForSequenceClassification"), ("yoso", "YosoForSequenceClassification"), ("zamba", "ZambaForSequenceClassification"), + ("zamba2", "Zamba2ForSequenceClassification"), ] ) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7674ea51a53..11c35c0cda7 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -566,6 +566,13 @@ "LlamaTokenizerFast" if is_tokenizers_available() else None, ), ), + ( + "zamba2", + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), + ), ] ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 36e1ff2cfe6..2af11b49a5f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -10139,6 +10139,34 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Zamba2ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Zamba2ForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Zamba2Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Zamba2PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class ZoeDepthForDepthEstimation(metaclass=DummyObject): _backends = ["torch"] From 70a602198a66d76df0093bc7c3687804605aceaa Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 4 Nov 2024 23:57:24 +0000 Subject: [PATCH 06/73] generated modeling and configuration --- .../models/zamba/modeling_zamba.py | 4 +- .../models/zamba2/configuration_zamba2.py | 245 ++ .../models/zamba2/modeling_zamba2.py | 2129 +++++++++++++++++ .../models/zamba2/modular_zamba2.py | 104 +- 4 files changed, 2468 insertions(+), 14 deletions(-) create mode 100644 src/transformers/models/zamba2/configuration_zamba2.py create mode 100644 src/transformers/models/zamba2/modeling_zamba2.py diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index dee7f898fcf..4e97116b563 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -922,7 +922,7 @@ def forward( return outputs -class HybridLayer(nn.Module): +class ZambaHybridLayer(nn.Module): def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer): super().__init__() self.shared_transf = shared_transf @@ -1201,7 +1201,7 @@ def __init__(self, config: ZambaConfig): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers))) + layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py new file mode 100644 index 00000000000..7801bd8e3ab --- /dev/null +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -0,0 
+1,245 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zamba2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +from ...configuration_utils import PretrainedConfig + + +class Zamba2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a + Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Zamba2 model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Zamba2Model`] + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 2560): + Dimension of the hidden representations. + ffn_hidden_size (`int`, *optional*, defaults to 4 * hidden_size): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 54): + Number of hidden layers in the model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). + mamba_headdim (``, *optional*, defaults to 64): + dimension of each Mamba2 heads (number of heads is set to 1 in this implementation). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if + `True` and kernels are not available. + mamba_d_state (`int`, *optional*, defaults to 64): + The dimension of the mamba state space latents. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size. + add_bias_linear (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in various layers. + gated_linear_unit (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use gated MLP. + use_shared_block_lora (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP + inside the shared transformer blocks. + lora_rank (`int`, *optional*, defaults to 128): + The rank of the LoRA modules inside the MLP of the shared transformer blocks. + """ + + model_type = "zamba2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + max_position_embeddings=4096, + tie_word_embeddings=True, + hidden_size=2560, + num_hidden_layers=54, + mamba_d_state=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_headdim=64, + mamba_ngroups=1, + time_step_min=0.001, + time_step_max=0.1, + time_step_floor=1e-4, + time_step_limit=(0.0, float("inf")), + mamba_dt_rank="auto", + n_mamba_heads=1, + mamba_conv_bias=True, + mamba_proj_bias=False, + hidden_mamba_act="silu", + use_mamba_kernels=True, + use_conv_bias=True, + chunk_size=256, + add_bias_linear=False, + intermediate_size=None, + gated_linear_unit=True, + hidden_act="gelu", + num_attention_heads=32, + num_key_value_heads=None, + sliding_window=None, + attention_dropout=0.0, + num_mem_blocks=1, + use_shared_block_lora=True, + use_shared_attention_lora=False, + lora_rank=128, + use_mem_eff_path=True, + use_mem_rope=False, + rope_theta=10000, + attention_hidden_size=None, +
attention_head_dim=None, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + use_long_context=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + if intermediate_size is None: + self.intermediate_size = 4 * hidden_size + else: + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.num_mem_blocks = num_mem_blocks + self.use_mem_rope = use_mem_rope + self.rope_theta = rope_theta + if attention_hidden_size is None: + self.attention_hidden_size = 2 * hidden_size + else: + self.attention_hidden_size = attention_hidden_size + if attention_head_dim is None: + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads + else: + self.attention_head_dim = attention_head_dim + self.attention_dropout = attention_dropout + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank + self.add_bias_linear = add_bias_linear + self.mamba_headdim = mamba_headdim + self.mamba_ngroups = mamba_ngroups + self.n_mamba_heads = n_mamba_heads + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.hidden_mamba_act = hidden_mamba_act + self.use_mamba_kernels = use_mamba_kernels + self.use_conv_bias = use_conv_bias + self.chunk_size = chunk_size + self.time_step_limit = time_step_limit + + self.gated_linear_unit = gated_linear_unit + self.use_shared_block_lora = use_shared_block_lora + self.use_shared_attention_lora = use_shared_attention_lora + self.lora_rank = lora_rank + self.use_long_context = use_long_context + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.time_step_floor = time_step_floor + if use_long_context: + self.max_position_embeddings = 16384 + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.num_attention_heads = num_attention_heads + self.kv_channels = self.hidden_size // self.num_attention_heads + self.num_query_groups = self.num_attention_heads + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + if intermediate_size is None: + self.ffn_hidden_size = 4 * self.hidden_size + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + self.use_mem_eff_path = use_mem_eff_path + + # Below, "mmamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + self.layers_block_type = ( + ["mamba"] + + (["mamba"] * 5 + ["hybrid"]) * 7 + + ["mamba"] * 4 + + ["hybrid"] + + ["mamba"] * 3 + + ["hybrid"] + + ["mamba"] * 2 + ) + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py new file mode 100644 index 00000000000..5c436a8ce53 --- /dev/null +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -0,0 +1,2129 @@ +# 
🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_zamba2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from itertools import cycle +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import ( + is_causal_conv1d_available, + is_mamba_ssm_available, +) +from .configuration_zamba2 import Zamba2Config + + +if is_mamba_ssm_available(): + #### from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined #### added + from mamba_ssm.ops.triton.selective_state_update import selective_state_update +else: + #### selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + selective_state_update = None, None, None #### added + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +#### is_fast_path_available = all( +#### (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +#### ) +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added + + +logger = logging.get_logger(__name__) + + +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Zamba2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return 
self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + return self.weight * hidden_states.to(input_dtype) + + +class Zamba2RotaryEmbedding(nn.Module): + def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + if config.use_long_context: + a = 8 + base = base * a ** (dim / (dim - 2)) + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} + for i in range(config.num_hidden_layers): + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + +def layer_type_list(config: Zamba2Config): + """ + Returns list of layer ids containing hybrid layers + """ + ll = [] + i = 0 + for val in config.layers_block_type: + if val == "hybrid": + ll.append(i) + i += 1 + return ll + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Zamba2Attention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). 
+ Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + """ + + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_mem_blocks=None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.attention_hidden_size = config.attention_hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.attention_head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.attention_hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.num_mem_blocks = num_mem_blocks + self.rope_theta = config.rope_theta + self.layer_block_map = layer_type_list(config) + + ### add to config: + # config.attention_hidden_size + # config.attention_head_dim + # config.max_position_embeddings + if config.use_shared_attention_lora: + self.linear_q_lora_A_list = nn.ParameterList([]) + self.linear_q_lora_B_list = nn.ParameterList([]) + self.linear_k_lora_A_list = nn.ParameterList([]) + self.linear_k_lora_B_list = nn.ParameterList([]) + self.linear_v_lora_A_list = nn.ParameterList([]) + self.linear_v_lora_B_list = nn.ParameterList([]) + + for i in range(self.num_mem_blocks): + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + self.linear_q_lora_A_list.append(linear_q_lora_A) + self.linear_q_lora_B_list.append(linear_q_lora_B) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + self.linear_k_lora_A_list.append(linear_k_lora_A) + self.linear_k_lora_B_list.append(linear_k_lora_B) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + self.linear_v_lora_A_list.append(linear_v_lora_A) + self.linear_v_lora_B_list.append(linear_v_lora_B) + + if config.use_mem_rope: + self.rotary_emb = Zamba2RotaryEmbedding( + config, + self.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=self.rope_theta, + ) + + self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + 
use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + lora_layer_idx = self.layer_dic[layer_idx] + linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[lora_layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[lora_layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[lora_layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[lora_layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward +# dropped use_sliding_windows from the arguments of 
self._flash_attention_forward +class Zamba2FlashAttention2(Zamba2Attention): + """ + Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + layer_idx = self.layer_block_map[layer_idx] + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, 
self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=softmax_scale, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: +# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention +class Zamba2SdpaAttention(Zamba2Attention): + """ + Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + def forward( + self, + hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.use_shared_attention_lora: + linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] + linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] + q_lora_output = linear_q_lora_A(hidden_states) + q_lora_output = linear_q_lora_B(q_lora_output) + query_states = self.q_proj(hidden_states) + query_states = query_states + q_lora_output + linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] + linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] + k_lora_output = linear_k_lora_A(hidden_states) + k_lora_output = linear_k_lora_B(k_lora_output) + key_states = self.k_proj(hidden_states) + key_states = key_states + k_lora_output + linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] + linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] + v_lora_output = linear_v_lora_A(hidden_states) + v_lora_output = linear_v_lora_B(v_lora_output) + value_states = self.v_proj(hidden_states) + value_states = value_states + v_lora_output + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.config.use_mem_rope: + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + softmax_scale = 1 / math.sqrt(self.head_dim / 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
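+            # the explicit `scale` passed below is 1/sqrt(head_dim/2) rather than 1/sqrt(head_dim) because the attention input is the concatenation of two hidden states (see the Zamba2Attention docstring)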
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + scale=softmax_scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + # pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + # if len(input_tensor.shape) == 3: + if input_tensor.ndim == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +class Zamba2Mamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: Zamba2Config, layer_idx: int = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * self.hidden_size) + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias # add this with default True + self.activation = "silu" + self.act = nn.SiLU() + + self.n_groups = config.mamba_ngroups + self.head_dim = config.mamba_headdim + self.num_heads = self.intermediate_size // self.head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) + self.time_step_min = config.time_step_min # add this, with same default as zamba1 + self.time_step_max = config.time_step_max # add this, with same default as zamba1 + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=True, + kernel_size=config.mamba_d_conv, + groups=self.conv_dim, + padding=config.mamba_d_conv - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.add_bias_linear, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated(self.intermediate_size, eps=1e-5) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # set up dimensions for reshapes later + + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads + + # getting projected states from cache if it exists + if cache_params is not None and cache_params.seqlen_offset > 0: + in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 + split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] + _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1) + + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + A = -torch.exp(self.A_log.float()) # (nheads,) + + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + out = self.out_proj(hidden_states)[:, None, ...] + # if no cache is found, calling the kernel + else: + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + # 1. 
Gated MLP's linear projection + projected_states = self.in_proj(hidden_states) + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + if self.training and cache_params is None: + out, ssm_state = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=True, + **dt_limit_kwargs, + ) + + else: + gate, hidden_states_B_C, time_step = torch.split( + projected_states, + [self.intermediate_size, self.conv_dim, self.num_heads], + dim=-1, + ) + + # 1D Convolution + if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] + ) # (B, L, self.d_inner + 2 * ngroups * d_state) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2)[:, :seq_len] + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + time_step, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + scan_output = scan_output.view(batch_size, seq_len, -1) + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + # Gated MLP's linear projection + projected_states = self.in_proj(input_states.squeeze(1)) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 + _, _, gate, hidden_states, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # Convolution sequence transformation + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx].clone() + ssm_state = ssm_state.to(hidden_states.device) + if cache_params.seqlen_offset > 0: + conv_state = 
cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + # handle batched generation - states are copied through + conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding + else: + hidden_states = hidden_states.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + else: + ssm_state = torch.zeros( + (batch_size, self.num_heads, self.head_dim, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) + hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_params.seqlen_offset > 0: + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] 
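A minimal standalone sketch, with hypothetical toy sizes, of the convolution-cache shortcut used by the decoding branch above: rolling the cached window and taking a dot product with the depthwise kernel gives the same value as running the causal depthwise `Conv1d` over the full sequence and reading off the newest position. None of the names below are the real module's attributes.

```python
import torch

# Toy sizes; the real module derives these from Zamba2Config.
batch, channels, kernel, seq_len = 2, 6, 4, 10

conv1d = torch.nn.Conv1d(channels, channels, kernel, groups=channels, padding=kernel - 1)
x = torch.randn(batch, channels, seq_len)             # channels-first, like hidden_states.transpose(1, 2)

# Full causal convolution over the sequence (what the prefill path computes).
full = conv1d(x)[..., :seq_len]

# Decode-path shortcut: keep only the last `kernel` inputs and dot them with the kernel.
conv_state = x[..., -kernel:]                          # what the rolled cache holds at steady state
step = (conv_state * conv1d.weight[:, 0, :]).sum(-1) + conv1d.bias

print(torch.allclose(step, full[..., -1], atol=1e-5))  # True
```

The prefill branch stores the left-padded last `conv_kernel_size` tokens in `conv_states` for exactly this reason: the next single-token step can then reuse the roll-and-dot shortcut instead of re-running the convolution.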
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = torch.exp(dt[..., None] * A) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = dB * hidden_states[..., None] + + # State calculation + cache_params.ssm_states[self.layer_idx].copy_( + cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] + else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_min) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)] + + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. 
Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # First, contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + + # Step 2: Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Step 3: Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3) + + # (right term of low-rank factorization of off-diagonal blocks; B terms) + + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] + # permute back B * decay states + states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) + if cache_params is not None and cache_params.seqlen_offset > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + + states_permuted = states.permute(0, 2, 1, 3, 4) + result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2) + new_states = result.permute(0, 2, 1, 3, 4) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + # compute Yoff + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + if ssm_state is not None and cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +ZAMBA2_ATTENTION_CLASSES = { + "eager": Zamba2Attention, + "flash_attention_2": Zamba2FlashAttention2, + "sdpa": Zamba2SdpaAttention, +} + + +class Zamba2MLP(nn.Module): + def __init__(self, config: Zamba2Config, num_mem_blocks=None): + """ + Shared MLP layer. To the intermediate activations of the MLP, we add un-shared LoRA's, which + introduce some amount of diversification across the shared MLP layers. + """ + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_mem_blocks = num_mem_blocks + self.ffn_intermediate_size = config.intermediate_size + + self.act_fn = ACT2FN[config.hidden_act] + + def gated_act_fn(x): + x = torch.chunk(x, 2, dim=-1) + + return self.act_fn(x[0]) * x[1] + + self.gated_act_fn = gated_act_fn + + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) + self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) + + if self.config.use_shared_block_lora: + self.gate_up_proj_lora_A_list = nn.ModuleList([]) + self.gate_up_proj_lora_B_list = nn.ModuleList([]) + for i in range(self.num_mem_blocks): + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) + self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) + self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) + + layer_block_map = layer_type_list(config) + self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} + + def forward(self, hidden_state, layer_idx=None): + if self.config.use_shared_block_lora: + layer_idx = self.layer_dic[layer_idx] + gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] + gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] + lora_output = gate_up_proj_lora_A(hidden_state) + lora_output = gate_up_proj_lora_B(lora_output) + intermediate_state = self.gate_up_proj(hidden_state) + hidden_state = intermediate_state + lora_output + else: + hidden_state = self.gate_up_proj(hidden_state) + + hidden_state = self.gated_act_fn(hidden_state) + output = self.down_proj(hidden_state) + return output + + +def count_mem_blocks_in_config(config: Zamba2Config): + """ + Count number of shared blocks + """ + num_gs = 0 + for val in config.layers_block_type: + if val == "hybrid": + num_gs += 1 + return num_gs + + +class Zamba2AttentionDecoderLayer(nn.Module): + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + super().__init__() + num_gs = 
count_mem_blocks_in_config(config) + self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx=-1, num_mem_blocks=num_gs + ) + self.feed_forward = Zamba2MLP(config, num_mem_blocks=num_gs) + self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Zamba2Mamba2DecoderLayer(nn.Module): + def __init__(self, config: Zamba2Config, layer_idx: int): + super().__init__() + self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + transformer_hidden_states: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + residual = hidden_states + + # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712). 
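A schematic sketch of the hybrid-block wiring described in the comment above and in the `Zamba2AttentionDecoderLayer` docstring (fig. 2 / eq. (6) of the Zamba paper), using toy stand-in modules; none of these are the real classes, and the stand-ins skip the attention internals and the pre-attention/pre-FFN norms.

```python
import torch
from torch import nn

hidden, seq, batch = 16, 5, 2

# Hypothetical stand-ins, for illustration only.
shared_block = nn.Linear(2 * hidden, hidden)        # plays the role of the shared attention + MLP block
down_proj = nn.Linear(hidden, hidden, bias=False)   # the per-hybrid-layer `linear`
norm = nn.LayerNorm(hidden)                         # stands in for Zamba2RMSNorm
mamba = nn.Linear(hidden, hidden)                   # stands in for the mamba mixer

original_hidden_states = torch.randn(batch, seq, hidden)  # word-embedding output
hidden_states = torch.randn(batch, seq, hidden)           # output of the previous layer

# The shared block sees [hidden_states, original_hidden_states] concatenated on the feature dim.
transformer_hidden_states = shared_block(
    torch.cat([hidden_states, original_hidden_states], dim=-1)
)
transformer_hidden_states = down_proj(transformer_hidden_states)

# Eq. (6): the projected shared-block output is added to the mamba input, with a residual
# connection around the mamba block.
out = hidden_states + mamba(norm(hidden_states + transformer_hidden_states))
print(out.shape)  # torch.Size([2, 5, 16])
```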
+ hidden_states = ( + hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + attention_mask=attention_mask, + ) + + self_attn_weights = None + + # residual connection after mamba + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class Zamba2HybridLayer(nn.Module): + def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2Mamba2DecoderLayer): + super().__init__() + self.shared_transf = shared_transf + self.linear = linear + self.mamba_decoder = mamba + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with + hidden activations to form the input of the shared transformer layer. + layer_idx (`int`): layer number. + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + layer_outputs = self.shared_transf( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + transformer_hidden_states = layer_outputs[0] + + if output_attentions: + self_attn_weights = layer_outputs[1] + + transformer_hidden_states = self.linear(transformer_hidden_states) + + layer_outputs = self.mamba_decoder( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + if output_attentions: + layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:] + + return layer_outputs + + +ZAMBA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Zamba2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2PreTrainedModel(PreTrainedModel): + config_class = Zamba2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2Mamba2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _is_stateful = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, Zamba2Mamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim + dt = torch.exp( + torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + @classmethod + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. + Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. + + Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` + as we do not want to disable Flash Attention 2 in Zamba2. + """ + """ + Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. + Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. 
+ """ + config = super()._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "flash_attention_2": + config._attn_implementation = "eager" + """ + Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` + as we do not want to disable Flash Attention 2 in Zamba2. + """ + return super(ZambaPreTrainedModel)._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + return config + + +_CONFIG_FOR_DOC = "Zamba2Config" + + +ZAMBA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. + See the `HybridMambaAttentionDynamicCache` class for more details. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", + ZAMBA2_START_DOCSTRING, +) +class Zamba2Model(Zamba2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ZambaDecoderLayer`] + + Args: + config: ZambaConfig + """ + + def __init__(self, config: Zamba2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] + mamba_layers = [] + linear_layers = [] + self.layers_block_type = config.layers_block_type + for i in range(config.num_hidden_layers): + if config.layers_block_type[i] == "mamba": + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + elif config.layers_block_type[i] == "hybrid": + linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) + mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers = iter(mamba_layers) + linear_layers = iter(linear_layers) + blocks = cycle(blocks) + layers = [] + self._tied_weights_keys = [] + for layer_id, layer_type in enumerate(self.layers_block_type): + if layer_type == "hybrid": + prefix_name = f"layers.{layer_id}." 
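A toy illustration of the layer assembly happening in this loop: hybrid positions reuse the same small pool of shared transformer blocks round-robin (via `itertools.cycle`), while every layer gets its own mamba decoder, which is why the shared block's parameters are collected as tied-weight keys just below. The layer layout here is made up, not the real 54-layer Zamba2 configuration.

```python
from itertools import cycle

# Hypothetical layout with num_mem_blocks = 2.
layers_block_type = ["mamba", "mamba", "hybrid", "mamba", "hybrid", "mamba", "hybrid"]
shared_blocks = cycle(["shared_block_0", "shared_block_1"])

layers = []
for i, kind in enumerate(layers_block_type):
    if kind == "hybrid":
        # shared block (reused) + per-layer linear + per-layer mamba decoder
        layers.append((i, next(shared_blocks), f"mamba_{i}"))
    else:
        layers.append((i, f"mamba_{i}"))

print([entry[1] for entry in layers if len(entry) == 3])
# ['shared_block_0', 'shared_block_1', 'shared_block_0']  -> the same weights appear at several depths
```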
+ tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_proj.weight", + "shared_transf.feed_forward.up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) + else: + layers.append(next(mamba_layers)) + self.layers = nn.ModuleList(layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + original_hidden_states = torch.clone(inputs_embeds) + # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer + + if use_cache and past_key_values is None: + logger.warning_once( + "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba2, JAMBA->ZAMBA2 +class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin): + def __init__(self, config: Zamba2Config): + super().__init__(config) + self.model = Zamba2Model(config) + self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Zamba2ForCausalLM + + >>> model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-7B-v1") + >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-v1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + +@add_start_docstrings( + """ + The Zamba2 Model with a sequence classification head on top (linear layer). 
+ + [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + ZAMBA2_START_DOCSTRING, +) +class Zamba2ForSequenceClassification(Zamba2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Zamba2Model(config) + self._tied_weights_keys = self.model._tied_weights_keys + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 7e9dccf713a..09ac5b18d53 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -49,7 +49,6 @@ is_torchdynamo_compiling, ) - if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined @@ -63,10 +62,8 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) from ..zamba.modeling_zamba import ( - ZambaAttentionDecoderLayer, ZambaMambaDecoderLayer, ZambaAttention, - HybridLayer, ZambaForCausalLM, ZambaForSequenceClassification, ZambaModel, @@ -319,9 +316,11 @@ def layer_type_list(config: Zamba2Config): Returns list of layer ids containing hybrid layers """ ll = [] - for i, val in 
enumerate(config.layers_block_type): + i = 0 + for val in config.layers_block_type: if val == 'hybrid': ll.append(i) + i += 1 return ll @@ -334,7 +333,8 @@ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): Assumes that we only have tensors of either size 4 or 3 """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + # pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) @@ -349,7 +349,8 @@ def reshape_into_chunks(input_tensor, pad_size, chunk_size): # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] input_tensor = pad_tensor_by_size(input_tensor, pad_size) - if len(input_tensor.shape) == 3: + # if len(input_tensor.shape) == 3: + if input_tensor.ndim == 3: # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) else: @@ -1317,12 +1318,14 @@ def forward(self, hidden_state, layer_idx = None): return output -class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): +class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) + super().__init__() num_gs = count_mem_blocks_in_config(config) self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_mem_blocks = num_gs) self.feed_forward = Zamba2MLP(config, num_mem_blocks = num_gs) + self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -1388,11 +1391,82 @@ class Zamba2Mamba2DecoderLayer(ZambaMambaDecoderLayer): def __init__(self, config: Zamba2Config, layer_idx: int): super().__init__(config, layer_idx) self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class Zamba2HybridLayer(nn.Module): + def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2Mamba2DecoderLayer): + super().__init__() + self.shared_transf = shared_transf + self.linear = linear + self.mamba_decoder = mamba + + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with + hidden activations to form the input of the shared transformer layer. + layer_idx (`int`): layer number. 
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + """ + + layer_outputs = self.shared_transf( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + transformer_hidden_states = layer_outputs[0] -class Zamba2HybridLayer(HybridLayer): - pass + if output_attentions: + self_attn_weights = layer_outputs[1] + transformer_hidden_states = self.linear(transformer_hidden_states) + + layer_outputs = self.mamba_decoder( + hidden_states, + transformer_hidden_states=transformer_hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + if output_attentions: + layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:] + + return layer_outputs ZAMBA2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -1548,8 +1622,11 @@ class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): """ def __init__(self, config: Zamba2Config): - super().__init__(config) + Zamba2PreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] mamba_layers = [] linear_layers = [] @@ -1580,12 +1657,15 @@ def __init__(self, config: Zamba2Config): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) + layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) self._attn_implementation = config._attn_implementation + self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() From 685906a0f541d731f183645d3fa07f268fd89b09 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 5 Nov 2024 04:58:51 +0000 Subject: [PATCH 07/73] generated modeling and configuration --- .../models/zamba2/configuration_zamba2.py | 18 +- .../models/zamba2/modeling_zamba2.py | 20 +- .../models/zamba2/modular_zamba2.py | 43 +- utils/modular_model_converter.py | 1656 +++++++---------- utils/modular_model_converter.textClipping | Bin 0 -> 259 bytes 5 files changed, 716 insertions(+), 1021 deletions(-) create mode 100644 utils/modular_model_converter.textClipping diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 7801bd8e3ab..f2457ac400f 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -156,6 +156,14 @@ def __init__( use_long_context=False, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.tie_word_embeddings = tie_word_embeddings @@ -225,7 +233,7 @@ def __init__( self.num_logits_to_keep = num_logits_to_keep self.use_mem_eff_path = use_mem_eff_path - # Below, "mmamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) self.layers_block_type = ( ["mamba"] + (["mamba"] * 5 + ["hybrid"]) * 7 @@ -235,11 +243,3 @@ def __init__( + ["hybrid"] + ["mamba"] * 2 ) - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5c436a8ce53..fa73fcf5905 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -741,7 +741,7 @@ def segment_sum(input_tensor): return tensor_segsum -class Zamba2Mamba2Mixer(nn.Module): +class 
Zamba2MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) @@ -1299,10 +1299,10 @@ def forward( return outputs -class Zamba2Mamba2DecoderLayer(nn.Module): +class Zamba2MambaDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, layer_idx: int): super().__init__() - self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx @@ -1368,7 +1368,7 @@ def forward( class Zamba2HybridLayer(nn.Module): - def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2Mamba2DecoderLayer): + def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): super().__init__() self.shared_transf = shared_transf self.linear = linear @@ -1467,7 +1467,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): config_class = Zamba2Config base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2Mamba2DecoderLayer"] + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = False @@ -1484,7 +1484,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Zamba2Mamba2Mixer): + elif isinstance(module, Zamba2MambaMixer): module.A_log._no_weight_decay = True module.D._no_weight_decay = True @@ -1617,10 +1617,10 @@ def _check_and_enable_flash_attn_2( ) class Zamba2Model(Zamba2PreTrainedModel): """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ZambaDecoderLayer`] + Model consisting of *config.num_hidden_layers* layers. 
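A side note on the `_init_weights` hunk above: tagging `A_log` and `D` with `_no_weight_decay` only sets an attribute on those parameters; the training loop has to honor it. A generic sketch of how such a flag is typically consumed when building optimizer parameter groups (this is an illustration, not the transformers Trainer logic):

```python
import torch

def build_param_groups(model, weight_decay=0.1):
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # parameters tagged with `_no_weight_decay` (e.g. the mixer's A_log and D) skip decay
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# usage: optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)
```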
Args: - config: ZambaConfig + config: Zamba2Config """ def __init__(self, config: Zamba2Config): @@ -1635,10 +1635,10 @@ def __init__(self, config: Zamba2Config): self.layers_block_type = config.layers_block_type for i in range(config.num_hidden_layers): if config.layers_block_type[i] == "mamba": - mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) elif config.layers_block_type[i] == "hybrid": linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) - mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) blocks = cycle(blocks) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 09ac5b18d53..5398fa2d7b4 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -49,6 +49,7 @@ is_torchdynamo_compiling, ) + if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined @@ -74,9 +75,10 @@ ) from ...configuration_utils import PretrainedConfig -logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "Zamba2Config" +_CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B" + +logger = logging.get_logger(__name__) class Zamba2Config(PretrainedConfig): @@ -217,6 +219,13 @@ def __init__( use_long_context=False, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -288,17 +297,9 @@ def __init__( self.use_mem_eff_path = use_mem_eff_path - # Below, "mmamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) + # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) self.layers_block_type = ['mamba'] + (['mamba'] * 5 + ['hybrid']) * 7 + ['mamba'] * 4 + ['hybrid'] + ['mamba'] * 3 + ['hybrid'] + ['mamba'] * 2 - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - def count_mem_blocks_in_config(config: Zamba2Config): """ @@ -851,7 +852,7 @@ def forward( } -class Zamba2Mamba2Mixer(nn.Module): +class Zamba2MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
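To make the interleaving above concrete, here is a self-contained sketch that evaluates the default `layers_block_type` pattern and assigns each hybrid layer a shared attention block round-robin, mirroring the `cycle(blocks)` / `next(blocks)` bookkeeping in `Zamba2Model.__init__` (the `num_mem_blocks` value is illustrative):

```python
from itertools import cycle

# default interleaving pattern copied from the config above
layers_block_type = (
    ["mamba"]
    + (["mamba"] * 5 + ["hybrid"]) * 7
    + ["mamba"] * 4
    + ["hybrid"]
    + ["mamba"] * 3
    + ["hybrid"]
    + ["mamba"] * 2
)
print(len(layers_block_type), layers_block_type.count("hybrid"))  # 54 layers, 9 of them hybrid

# hybrid layers reuse a small pool of shared transformer blocks, assigned round-robin
num_mem_blocks = 2  # illustrative value
shared_blocks = cycle(range(num_mem_blocks))
assignment = {
    layer_idx: next(shared_blocks)
    for layer_idx, kind in enumerate(layers_block_type)
    if kind == "hybrid"
}
print(assignment)  # {6: 0, 12: 1, 18: 0, ...}
```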
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) @@ -1387,15 +1388,15 @@ def forward( return outputs -class Zamba2Mamba2DecoderLayer(ZambaMambaDecoderLayer): +class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer): def __init__(self, config: Zamba2Config, layer_idx: int): super().__init__(config, layer_idx) - self.mamba = Zamba2Mamba2Mixer(config=config, layer_idx=layer_idx) + self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) class Zamba2HybridLayer(nn.Module): - def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2Mamba2DecoderLayer): + def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): super().__init__() self.shared_transf = shared_transf self.linear = linear @@ -1492,7 +1493,7 @@ def forward( class Zamba2PreTrainedModel(ZambaPreTrainedModel): _supports_flash_attn_2 = True # Leaving this commented out for now until testing - # _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2Mamba2DecoderLayer"] + # _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] def _init_weights(self, module): std = self.config.initializer_range @@ -1504,7 +1505,7 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - elif isinstance(module, Zamba2Mamba2Mixer): + elif isinstance(module, Zamba2MambaMixer): module.A_log._no_weight_decay = True module.D._no_weight_decay = True @@ -1615,10 +1616,10 @@ def _check_and_enable_flash_attn_2( ) class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`ZambaDecoderLayer`] + Model consisting of *config.num_hidden_layers* layers. 
Args: - config: ZambaConfig + config: Zamba2Config """ def __init__(self, config: Zamba2Config): @@ -1633,10 +1634,10 @@ def __init__(self, config: Zamba2Config): self.layers_block_type = config.layers_block_type for i in range(config.num_hidden_layers): if config.layers_block_type[i] == "mamba": - mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) elif config.layers_block_type[i] == "hybrid": linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)) - mamba_layers.append(Zamba2Mamba2DecoderLayer(config, layer_idx=i)) + mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i)) mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) blocks = cycle(blocks) diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index b1dfa18a7a9..bda143c2577 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -17,14 +17,13 @@ import importlib import os import re -from abc import ABC, abstractmethod from collections import defaultdict, deque -from typing import Dict, Set +from typing import Dict, List, Optional, Set import libcst as cst from check_copies import run_ruff from create_dependency_mapping import find_priority_list -from libcst import ClassDef, CSTVisitor +from libcst import ClassDef, CSTTransformer, CSTVisitor from libcst import matchers as m from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider @@ -35,6 +34,13 @@ logger = logging.get_logger(__name__) +# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the +# value from the dependency is used, then mapped to current name convention, resulting in wrong value. +# The corresponding mapped value is used to define the file target for the assignment +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC": "modeling", +} + AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from {relative_path}. # Do NOT edit this file manually as any edits will be overwritten by the generation of @@ -55,23 +61,137 @@ def get_module_source_from_name(module_name: str) -> str: return source_code -def preserve_case_replace(text, patterns: dict, default_name: str): - # Create a regex pattern to match all variations - regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) +class ClassFinder(CSTVisitor): + """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions. + For example if the visited code has + ```python3 + def init_value(): return 1 - def replace(match): - word = match.group(0) - result = patterns.get(word, default_name) - return result + class LlamaModel(PreTrainedModel): + def __init__(self): + super().__init__(self) + self.value = init_value() + ``` + then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]} - return compiled_regex.sub(replace, text) + The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by + checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the + dependence parent -> child. + When visiting such nodes, we update the dependency of the parent node, to take into account the visited node. 
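For context on the rename machinery being moved around here: `preserve_case_replace` joins all known name variants into one case-insensitive regex and replaces each match with its mapped counterpart, falling back to the default camel-case name. A standalone sketch of the same technique with a toy pattern table (the converter derives `patterns` from the old and new model names):

```python
import re

# toy pattern table; the converter builds this from e.g. old_name="zamba", new_name="zamba2"
patterns = {"zamba": "zamba2", "Zamba": "Zamba2", "ZAMBA": "ZAMBA2"}
default_name = "Zamba2"

regex = re.compile("|".join(re.escape(k) for k in patterns), re.IGNORECASE)

def preserve_case_replace(text: str) -> str:
    # each match is replaced by its mapped value; unknown casings fall back to the default
    return regex.sub(lambda m: patterns.get(m.group(0), default_name), text)

print(preserve_case_replace("class ZambaModel(ZAMBA_START_DOCSTRING)"))
# -> class Zamba2Model(ZAMBA2_START_DOCSTRING)
```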
-def convert_to_camelcase(text, old_name: str, default_old_name: str): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1) - return result + All `visit_XXX` correspond to the code executed when vising the cst.Node of type XXX. + """ + + METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) + + def __init__(self, python_module: cst.Module): + # fmt: off + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node + self.imports = {} # stores all import statements + self.function_def = {} # stores global scope function definition + self.assignments = {} # LLAMA_DOCSTRING + self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] + self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] + # fmt: on + + def _update_class_dependency(self, name, value): + """Update the dependency mapping for `name` with `value` by appending the previous + dependencies to the new `value`. + """ + dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value}) + self.first_lvl_dependency_mapping[name] = dep + + dep = set(self.class_dependency_mapping.get(value, set())) + dep |= set(self.class_dependency_mapping.get(name, {})) | set({value}) + self.class_dependency_mapping[name] = dep + + def visit_ClassDef(self, node: ClassDef) -> None: + """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies""" + self.classes[node.name.value] = node + for k in node.bases: # deal with inheritance + base_name = self.python_module.code_for_node(k) + self._update_class_dependency(node.name.value, base_name) + + def visit_SimpleStatementLine(self, node): + """ + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements + are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. 
+ """ + if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( + self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() + ): + left_hand_side = node.body[0].targets[0].target + if hasattr(left_hand_side, "value"): + if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys(): + self.assignments[left_hand_side.value] = node + else: + for idx, target in enumerate(list(left_hand_side.elements)): + if target.value.value not in ASSIGNMENTS_TO_KEEP.keys(): + self.assignments[target.value.value] = node.body[0].value.elements[idx].value + if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports[node.body[0].names] = node + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.function_def[node.name.value] = node + + def leave_If(self, node): + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports[stmt.body[0].names] = node + + def leave_Name(self, node): + if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys(): + parent = self.get_metadata(cst.metadata.ScopeProvider, node) + if not isinstance(parent, cst.metadata.scope_provider.GlobalScope): + self._update_class_dependency(parent._name_prefix.split(".")[0], node.value) + + def leave_Arg(self, node): + if m.matches(node.value, m.Name()): + parent = self.get_metadata(ParentNodeProvider, node) + if m.matches(parent, m.ClassDef()) and parent.bases: + self._update_class_dependency(parent.name.value, node.value.value) + + def leave_Dict(self, node): + parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent, m.Assign(targets=[m.AssignTarget()])): + name = parent.targets[0].target.value + if name in self.assignments: + for k in node.elements: + dep_name = k.value.value + if dep_name in self.classes: + self._update_class_dependency(name, dep_name) + + def leave_Decorator(self, node): + if hasattr(node.decorator, "args"): + for k in node.decorator.args: + if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value: + if k.value.func.value.value not in self.assignments: + raise ValueError( + f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}" + ) + parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) + scope = self.get_metadata(cst.metadata.ScopeProvider, node) + name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value + self._update_class_dependency(name, k.value.func.value.value) + elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments: + parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) + scope = self.get_metadata(cst.metadata.ScopeProvider, node) + name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value + self._update_class_dependency(name, k.value.value) + + def leave_Module(self, node): + """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def) + to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this. 
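A simplified, metadata-free illustration of the top-level assignment and import detection performed by `visit_SimpleStatementLine` above; walking `module.body` directly stands in for the `ParentNodeProvider` check that the statement's parent is the `Module` (the toy source and names are assumptions):

```python
import libcst as cst
from libcst import matchers as m

source = '''
GEMMA_INPUT_DOCSTRING = "THIS IS THE INPUT"
import os

def foo():
    local = 1  # not a module-level assign, ignored here
'''

module = cst.parse_module(source)
assignments, imports = {}, []
for stmt in module.body:  # only direct children of the module, i.e. global scope
    if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])])):
        name = stmt.body[0].targets[0].target.value
        assignments[name] = stmt
    elif m.matches(stmt, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
        imports.append(stmt)

print(list(assignments))  # ['GEMMA_INPUT_DOCSTRING']
print(len(imports))       # 1
```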
+ """ + self.global_nodes = {**self.assignments, **self.classes, **self.function_def} + # now sort the class dependency_mapping based on the position of the nodes + self.class_start_line = {} + for id, node in self.global_nodes.items(): + self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line class ReplaceNameTransformer(m.MatcherDecoratableTransformer): @@ -90,6 +210,8 @@ def __init__( new_name, given_old_name=None, given_new_name=None, + old_class_name: str = None, + new_class_name: str = None, ): super().__init__() self.old_name = old_name @@ -110,17 +232,70 @@ def __init__( self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") if self.default_old_name.isupper(): self.default_old_name = self.default_old_name.capitalize() + if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns: + # In last recourse, when the suffix of the new class is not the same as the old class, + # and if the old and new classes start with the default name, we keep the default class name + # and replace the old suffix with the new one. + # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration` + # where a model extends another model, but is used for a different task. + if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name): + self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :] + + def preserve_case_replace(self, text): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) + + def replace(match): + word = match.group(0) + result = self.patterns.get(word, self.default_name) + return result + + return compiled_regex.sub(replace, text) + + def convert_to_camelcase(self, text): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub( + rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 + ) + return result @m.leave(m.Name() | m.SimpleString() | m.Comment()) def replace_name(self, original_node, updated_node): if re.findall(r"# Copied from", updated_node.value): return cst.RemoveFromParent() - update = preserve_case_replace(updated_node.value, self.patterns, self.default_name) + update = self.preserve_case_replace(updated_node.value) return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name) - return updated_node.with_changes(name=cst.Name(new_name)) + return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) + + +def find_classes_in_file( + module: cst.Module, + old_id="llama", + new_id="gemma", + given_old_name=None, + given_new_name=None, + old_class_name=None, + new_class_name=None, +): + """Helper function to rename and then parse a source file using the ClassFinder""" + transformer = ReplaceNameTransformer( + old_id, + new_id, + given_old_name=given_old_name, + given_new_name=given_new_name, + old_class_name=old_class_name, + new_class_name=new_class_name, + ) + new_module = module.visit(transformer) + + wrapper = MetadataWrapper(new_module) + + class_finder = ClassFinder(new_module) + wrapper.visit(class_finder) + return class_finder DOCSTRING_NODE = 
m.SimpleStatementLine( @@ -237,12 +412,13 @@ def merge_docstrings(original_docstring, updated_docstring): class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None): + def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): self.python_module = python_module self.original_methods = original_methods self.updated_methods = updated_methods self.all_assign_target = {} self.deleted_targets = {} # child node can delete some arguments + self.class_name = class_name self.all_bases = all_bases or [] self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) @@ -261,6 +437,7 @@ def update_body(self, existing_body, new_statements): if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): target = self.python_module.code_for_node(node.body[0].target) self.deleted_targets[target] = node + continue for stmt in existing_body: if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): @@ -270,9 +447,6 @@ def update_body(self, existing_body, new_statements): continue if target in self.all_assign_target: stmt = self.all_assign_target[target] - # Skip the docstring (will be added later on, at the beginning) - elif m.matches(stmt, DOCSTRING_NODE): - continue comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() deduplicated_new_body.append(stmt) @@ -282,47 +456,17 @@ def update_body(self, existing_body, new_statements): code = self.python_module.code_for_node(node) comment_less_code = re.sub(r"#.*", "", code).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if node not in deduplicated_new_body and comment_less_code not in existing_nodes: + if ( + node not in deduplicated_new_body + and "super().__init__" not in comment_less_code + and comment_less_code not in existing_nodes + ): if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - deduplicated_new_body.append(node) + # HACK here to fix the pos_init() that has to be last we kinda do this. 
+ deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:] existing_nodes.add(comment_less_code) - - deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) - return deduplicated_new_body - def _fix_post_init_location(self, new_body: list[cst.CSTNode]): - """Fix the location of the `post_init()` in the new body, if we added statements after the call to - `super()` (it needs to be the very last statement called)""" - # Fix the post_init() that has to be last - for i, node in enumerate(new_body): - code = self.python_module.code_for_node(node) - comment_less_code = re.sub(r"#.*", "", code).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if "self.post_init(" in comment_less_code and i < len(new_body) - 1: - # Remove it and add it again at the end - new_body.pop(i) - new_body.append(node) - break - return new_body - - def _fix_init_location(self, new_body): - """Fix the location of the `super().__init__()` in the new body, if we had new statements before it.""" - start_index = 0 - for i, node in enumerate(new_body): - if m.matches(node, DOCSTRING_NODE) and i == start_index: - start_index += 1 - continue - code = self.python_module.code_for_node(node) - comment_less_code = re.sub(r"#.*", "", code).strip() - comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if "super().__init__" in comment_less_code and i > start_index: - # Remove it and add it again at the top after the docstrings - node = new_body.pop(i) - new_body = new_body[:start_index] + [node] + new_body[start_index:] - break - return new_body - def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: """Updates the body of the input `node`'s `func_name` function by replacing calls to super().func_name() with the source code of the parent class' `func_name`. @@ -335,11 +479,10 @@ def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CS new_body = [] has_super_call = False - for i, expr in enumerate(node.body): + for expr in node.body: if is_call_to_super(expr, func_name): has_super_call = True - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) - new_body = self._fix_init_location(new_body) + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) else: expr = expr.visit(self.transformer) if m.matches(expr, DOCSTRING_NODE): @@ -381,463 +524,11 @@ def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> c return updated_node -def find_all_dependencies( - dependency_mapping: Dict[str, set], - start_entity: str | None = None, - initial_dependencies: set | None = None, - initial_checked_dependencies: set | None = None, - return_parent: bool = False, -) -> list | set: - """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of - BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. - - Args: - dependency_mapping (`Dict[str, set]`): - A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, - a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called - in `foo`'s definition. - start_entity (str | None, *optional*): - A key of `dependency_mapping`, indicating from which entity to start the search. 
- initial_dependencies (set | None, *optional*): - If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue - from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. - initial_checked_dependencies (set | None, *optional*): - If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. - return_parent (bool, *optional*): - If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note - that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. - Returns: - A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. - - Example: - Given the following structure in the `modular_xxx.py` file: - ``` - def foo1(): - pass - - def foo2(): - pass - - def bar(): - foo1() - - def foobar(): - bar() - foo2() - - class MyLayer(SomeOtherModelLayer): - def forward(...): - foobar() - ``` - and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: - ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} - find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) - >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] - ``` - That is, all the functions needed (and potentially their immediate parent) so that the function to be added - in MyLayer (`foobar`) can work correctly. - """ - if initial_dependencies is None and start_entity is not None: - initial_dependencies = dependency_mapping[start_entity] - if initial_checked_dependencies is None: - initial_checked_dependencies = set() - - dependency_queue = deque(initial_dependencies) - all_dependencies = set() - all_dependencies_with_parent = [] - checked_dependencies = set(initial_checked_dependencies) - parents = {initial_dep: start_entity for initial_dep in initial_dependencies} - while len(dependency_queue) > 0: - # Pick element to visit - current = dependency_queue.popleft() - if current not in checked_dependencies: - # Add the dependencies - all_dependencies.add(current) - all_dependencies_with_parent += [(current, parents[current])] - if current in dependency_mapping.keys(): - # Update dependency queue - dependency_queue.extend(dependency_mapping[current]) - parents.update({dep: current for dep in dependency_mapping[current]}) - # add visited node to the list - checked_dependencies.add(current) - - if not return_parent: - return all_dependencies - # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) - return all_dependencies_with_parent - - -# These top-level variables will always use the value in the `modular_xxx.py` file -ASSIGNMENTS_TO_KEEP = { - "_CHECKPOINT_FOR_DOC", -} - - -class ClassDependencyMapper(CSTVisitor): - """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of - `global_names`. 
- """ - - def __init__( - self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None - ): - super().__init__() - self.class_name = class_name - self.dependencies = set() - self.global_names = global_names - self.objects_imported_from_modeling = ( - set() if objects_imported_from_modeling is None else objects_imported_from_modeling - ) - - def visit_Name(self, node): - if ( - node.value != self.class_name - and node.value in self.global_names - and node.value not in self.objects_imported_from_modeling - ): - self.dependencies.add(node.value) - - -def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: - """Create immediate dependencies for a class node based on the `global_names`.""" - temp_module = cst.Module(body=[node]) - visitor = ClassDependencyMapper(node.name.value, global_names) - temp_module.visit(visitor) - return visitor.dependencies - - -def augmented_dependencies_for_class_node( - node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None -) -> set: - """Create augmented dependencies for a class node based on a `mapper`. - Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. - """ - temp_module = cst.Module(body=[node]) - visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling) - temp_module.visit(visitor) - return mapper.augment_dependencies(visitor.dependencies) - - -# All the potential file types to create -ALL_FILE_TYPES = ( - "modeling", - "configuration", - "tokenization", - "processing", - "image_processing", - "feature_extractor", -) - - -class ModuleMapper(CSTVisitor, ABC): - """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. - Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in - `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`). - It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the - modeling files that will be visited. +def replace_call_to_super( + class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] +): """ - - METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) - - def __init__(self, python_module: cst.Module): - # fmt: off - self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!) - self.imports = [] # stores all import statements - self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes - self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. 
dependencies immediately in the function/assignment definition) - self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes - self.current_function = None # this keeps track of the current module-scope function - self.current_assignment = None # this keeps track of the current module-scope assignment - # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency - self.objects_imported_from_modeling = set() - # regex pattern joining every possible file type - self.match_patterns = "|".join(ALL_FILE_TYPES) - # fmt: on - - def visit_ImportFrom(self, node): - """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have - `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs - to be added (because it will be part of the imports)""" - import_module = self.python_module.code_for_node(node.module) - import_statement = "." * len(node.relative) + import_module - if re.search(rf"^\.({self.match_patterns})_.*", import_statement): - for imported_object in node.names: - # If an alias is present, we record it and not the original name - if imported_object.evaluated_alias is not None: - self.objects_imported_from_modeling.add(imported_object.evaluated_alias) - else: - self.objects_imported_from_modeling.add(imported_object.evaluated_name) - - def visit_SimpleStatementLine(self, node): - """ - Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements - are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. - """ - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - simple_top_level_assign_structure = m.SimpleStatementLine( - body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] - ) - if m.matches(parent_node, m.Module()): - if m.matches(node, simple_top_level_assign_structure): - left_hand_side = node.body[0].targets[0].target.value - self.current_assignment = left_hand_side - self.assignments[left_hand_side] = node - elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports.append(node) - - def leave_SimpleStatementLine(self, node): - # No need to check for the parent here -> everytime we exit one, it should be None anyway independently of where the - # SimpleStatement is located - self.current_assignment = None - - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_function = node.name.value - self.functions[node.name.value] = node - - def leave_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_function = None - - def visit_If(self, node): - for stmt in node.body.body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports.append(node) - - def visit_ClassDef(self, node: ClassDef) -> None: - """Record class nodes to create their dependencies at the end.""" - self.classes[node.name.value] = node - - def visit_Name(self, node: cst.Call): - """This is used to create a mapping from module-scope functions and assignments to objects used inside them.""" - if self.current_function is not None: - self.object_dependency_mapping[self.current_function].add(node.value) - if self.current_assignment 
is not None: - self.object_dependency_mapping[self.current_assignment].add(node.value) - - def leave_Module(self, node): - """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies - based on their position in the code later. We use the PositionProvider metadata wrapper for this. - We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in - `self.global_nodes`. - """ - # assign all nodes - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - # now sort the class dependency_mapping based on the position of the nodes - self.start_lines = {} - for id, node in self.global_nodes.items(): - self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line - - # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that - # are not part of the recorded objects (i.e. built-in variables, imports, etc) - global_objects = set(self.global_nodes.keys()) - for object_name, dependencies in self.object_dependency_mapping.items(): - self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} - - def _compute_recursive_object_dependencies(self) -> dict[str, set]: - """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the - following file: - ``` - def foo(): - pass - - def bar(): - foo() - - def test(): - bar() - ``` - this visitor can only record immediate dependencies, i.e. it will record the following - `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create - the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`. - """ - recursive_dependencies = {} - for object_name in self.object_dependency_mapping.keys(): - all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name) - recursive_dependencies[object_name] = all_dependencies - return recursive_dependencies - - def augment_dependencies(self, dependencies: set[str]) -> set[str]: - """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and - **assignments** present in the `dependencies`. - """ - new_dependencies = dependencies.copy() - # Go through the set of dependencies - for dep in tuple(dependencies): - if dep in self.object_recursive_dependency_mapping.keys(): - new_dependencies.update(self.object_recursive_dependency_mapping[dep]) - return new_dependencies - - def compute_class_dependencies(self): - """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies.""" - self.class_dependency_mapping = {} - for class_name, class_node in self.classes.items(): - dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) - # Correctly augment class dependencies with all needed objects - self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies) - - @abstractmethod - def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - raise NotImplementedError - - -class ModelFileMapper(ModuleMapper): - """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file - in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. 
- For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes - care of correctly merging dependencies, then finalizes all dependency graph computations. - Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. - For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies - of the modeling files as well. - """ - - def __init__(self, python_module: cst.Module): - super().__init__(python_module) - - def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]: - """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that - will be created based on the modular. - """ - relative_order = {} - idx = 0 - classes = sorted( - [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x] - ) - # This is because for merged dependencies, we only have relative order in the other visited file, so we need - # to track dependency order relative to a given class - if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): - raise ValueError("Cannot correctly find the relative order of the dependencies.") - - remaining_dependencies = missing_dependencies.copy() - - # Start by tracking relative order class by class - for class_name in classes: - class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) - original_dependencies = [] - merged_dependencies = [] - # We need to differentiate between nodes that were already present (we can get relative order globally) and - # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) - for class_dep in class_dependencies: - if class_dep in self.start_lines: - original_dependencies.append(class_dep) - else: - merged_dependencies.append(class_dep) - # Sort both list according to the order in their respective file - original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) - merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - - # Add all original node first, then merged ones - for dep in original_dependencies + merged_dependencies: - remaining_dependencies.remove(dep) - relative_order[dep] = idx - idx += 1 - # Add the class itself - remaining_dependencies.remove(class_name) - relative_order[class_name] = idx - idx += 1 - - # Now add what still remains - remaining_dependencies = tuple(remaining_dependencies) - original_dependencies = [] - merged_dependencies = [] - for dep in remaining_dependencies: - if dep in self.modular_file_start_lines: - merged_dependencies.append(dep) - else: - original_dependencies.append(dep) - # Sort both list according to the order in their respective file - original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) - merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) - - # Add all original node first, then merged ones - for dep in original_dependencies + merged_dependencies: - relative_order[dep] = idx - idx += 1 - - return relative_order - - def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): - """Update the global nodes and function dependency mapping with those from the modular file. 
- - Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies - instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). - """ - # Add/overwrite all needed function nodes and dependencies - self.functions.update(functions) - self.object_dependency_mapping.update( - {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} - ) - - def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): - """Update the global nodes with the assignment from the modular file. - - Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is - in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the - big docstrings. - """ - for assignment, node in assignments.items(): - if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: - self.assignments[assignment] = node - if assignment in object_mapping: - self.object_dependency_mapping[assignment] = object_mapping[assignment] - - def _merge_classes(self, classes: dict[str, cst.CSTNode]): - """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and - are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined - classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we - do not add the new classes to `self.classes`, but only to `global_nodes`. - """ - # Add/overwrite all needed function nodes and dependencies - self.global_nodes.update( - { - name: node - for name, node in classes.items() - if name not in self.classes and name not in self.objects_imported_from_modeling - } - ) - - def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): - """Merge classes, functions and assignments from the modular definitions into the current module file, - then record the relative order of all nodes. - Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the - merge with other files dependencies. - """ - self._merge_functions(functions, object_mapping) - self._merge_assignments(assignments, object_mapping) - self._merge_classes(classes) - self.modular_file_start_lines = start_lines - - # Correctly re-set the global nodes at this point - self.global_nodes.update(self.functions) - self.global_nodes.update(self.assignments) - # Create the global mapping of recursive dependencies for functions and assignments - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - - @classmethod - def visit_and_merge_dependencies( - cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines - ) -> "ModelFileMapper": - wrapper = MetadataWrapper(module) - mapper = cls(module) - wrapper.visit(mapper) - # Merge dependencies - mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) - # Create the class dependencies graph - mapper.compute_class_dependencies() - return mapper - - -def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): - """ - Replace a class node which inherits from another modeling class. 
This function works in the following way: - - start from the base class node of the inherited class (a cst.Node) - - replace all methods of the base node with the methods defined in the child class - - append all new methods defined in the child class - - replace all calls to super() with the unravelled code + Given the `class_name`, the `updated_node`'s call to super are unpacked. | ```python | | ```python | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): @@ -856,15 +547,14 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename | self.post_init() | ``` """ - all_bases = [k.value.value for k in class_node.bases] - - original_node = mapper.classes[renamed_super_class] + original_node = class_finder.classes[class_name] original_methods = { - f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f + f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f for f in original_node.body.body } updated_methods = { - f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body + f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f + for f in updated_node.body.body } end_meth = [] @@ -872,7 +562,7 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename docstring_node = [] # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict for func in original_node.body.body: - name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) + name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None: new_params = updated_methods[name].params # Replace the method in the replacement class, preserving decorators @@ -883,23 +573,19 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename new_params = new_params.with_changes( params=list(parent_params.values()), star_kwarg=func.params.star_kwarg ) - # Keep decorators in `modular_xxx.py` if any, else original decorators - new_decorators = ( - updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators - ) if not re.match( r"\ndef .*\(.*\):\n raise.*Error\(.*", - mapper.python_module.code_for_node(updated_methods[name]), + class_finder.python_module.code_for_node(updated_methods[name]), ): - func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators) + func = func.with_changes(body=updated_methods[name].body, params=new_params) else: continue if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): - target = mapper.python_module.code_for_node(func.body[0].targets[0]) + target = class_finder.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = mapper.python_module.code_for_node(func.body[0].target) + target = class_finder.python_module.code_for_node(func.body[0].target) assign_targets[target] = func elif m.matches(func, DOCSTRING_NODE): docstring_node = [func] @@ -907,8 +593,8 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename end_meth.append(func) # Port new methods that are defined only in modular-file and append at the end - for func in class_node.body.body: - name = 
func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) + for func in updated_node.body.body: + name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) if m.matches(func, DOCSTRING_NODE): # This processes the docstring of the class! # Extract the original docstring updated_docstring = func.body[0].value.value @@ -922,28 +608,22 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename end_meth.append(func) if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): # TODO we only use single assign might cause issues - target = mapper.python_module.code_for_node(func.body[0].targets[0]) + target = class_finder.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = mapper.python_module.code_for_node(func.body[0].target) + target = class_finder.python_module.code_for_node(func.body[0].target) assign_targets[target] = func end_meth = docstring_node + list(assign_targets.values()) + end_meth - # Replace the calls to `super()` with the unrolled code result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) new_replacement_class = new_module.visit( - SuperTransformer(temp_module, original_methods, updated_methods, all_bases) + SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases) ) new_replacement_body = new_replacement_class.body[0].body # get the indented block - # Use decorators redefined in `modular_xxx.py` if any - new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators - # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) - name = class_node.name - - return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) + return original_node.with_changes(body=new_replacement_body) TYPE_TO_FILE_TYPE = { @@ -952,483 +632,498 @@ def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, rename "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", - "ProcessorKwargs": "processing", - "ImagesKwargs": "processing", - "TextKwargs": "processing", } -def find_file_type(class_name: str) -> str: - """Based on a class name, find the file type corresponding to the class. - If the class name is `LlamaConfig` it will return `configuration`. - The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` +def get_new_part(class_name, base_class): """ - match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) - match = re.search(rf"({match_pattern})$", class_name) - if match: - file_type = TYPE_TO_FILE_TYPE[match.group(1)] + When `MyClassNameAttention` inherits from `MistralAttention`, we need + to process the name to properly find dependencies. + + Here we take what is the same (Attention) and what is different + when finding the dependencies. 
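`get_new_part` strips the suffix shared with the base class and snake_cases what is left; the snake_case regex is truncated in this hunk, so the conversion below is an assumed reconstruction of the intent rather than the exact expression:

```python
import re

def get_new_part(class_name: str, base_class: str) -> str:
    # find the longest common suffix, e.g. "Attention" for ("MyClassNameAttention", "MistralAttention")
    common_suffix_len = 0
    for i in range(1, min(len(class_name), len(base_class)) + 1):
        if class_name[-i] == base_class[-i]:
            common_suffix_len += 1
        else:
            break
    new_part = class_name[:-common_suffix_len] if common_suffix_len > 0 else class_name
    # assumed snake_case conversion: insert "_" before each inner capital, then lowercase
    return re.sub(r"(?<!^)(?=[A-Z])", "_", new_part).lower()

print(get_new_part("MyClassNameAttention", "MistralAttention"))  # my_class_name
```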
+ """ + common_suffix_len = 0 + for i in range(1, min(len(class_name), len(base_class)) + 1): + if class_name[-i] == base_class[-i]: + common_suffix_len += 1 + else: + break + + if common_suffix_len > 0: + new_part = class_name[:-common_suffix_len] else: - file_type = "modeling" - return file_type + new_part = class_name + # Convert the remaining new part to snake_case + snake_case = re.sub(r"(? 0: - new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) - imports_to_keep.append(new_node) + def bar(): + foo1() + def foobar(): + bar() + foo2() -def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: - """Get all the imports needed in the `body`, from the list of `all_imports`. - `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`. - Note: we need to use `isinstance` on scope assignements, m.matches apparently does not work here yet! + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies('foobar', dependency_mapping) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can + work correctly. """ - new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) - scopes = set(wrapper.resolve(ScopeProvider).values()) - unused_imports = set() - import_ref_count = {} - for scope in scopes: - for assignment in scope.assignments: - node = assignment.node - if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): - ref_count = len(assignment.references) - name = assignment.name - # Similar imports may be redefined, and only used between their 1st and 2nd definition - # so if we already have a ref count > 0, the imports is actually used - if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): - unused_imports.add(name) - import_ref_count[name] = ref_count - - imports_to_keep = [] - for node in all_imports: - if m.matches(node, m.If()): # handle safe imports - new_statements = [] - for stmt_node in node.body.body: - append_new_import_node(stmt_node, unused_imports, new_statements) - if len(new_statements) > 0: - new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) - imports_to_keep.append(new_node) - else: - append_new_import_node(node, unused_imports, imports_to_keep) - - protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] - usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] - # If the same import is both protected and unprotected, only keep the protected one - for protected_node in protected_import_nodes: - for stmt_node in protected_node.body.body: - usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]] - - # Protected imports always appear at the end of all imports - return usual_import_nodes + protected_import_nodes - - -def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: - """Split the `__all__` assignment found in the modular between each corresponding files.""" - all_all_per_file = {} - assign_node = node.body[0] - if 
isinstance(assign_node.value, cst.List): - # Extract the elements from the list - all_all_to_add = defaultdict(list) - for element in assign_node.value.elements: - if isinstance(element.value, cst.SimpleString): - # Remove quotes and add the string to the elements list - class_name = element.value.value - file = find_file_type(element.value.evaluated_value) - all_all_to_add[file] += [class_name] - for file, new_alls in all_all_to_add.items(): - new_node = assign_node.with_changes( - value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) - ) - all_all_per_file[file] = node.with_changes(body=[new_node]) - return all_all_per_file + all_dependencies = deque(dependency_mapping[function]) + all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] + checked_dependencies = set(function) + while len(all_dependencies) > 0: + # Pick element to visit + parent = all_dependencies.popleft() + if parent not in checked_dependencies: + # Update dependencies + all_dependencies.extend(dependency_mapping[parent]) + all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]] + # add visited node to the list + checked_dependencies.add(parent) + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent -class ModularFileMapper(ModuleMapper): - """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, - then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. - Calling the method `create_modules()` after visit will create all modules based on this modular file. - """ + +class PostModularConverterCleaner(CSTTransformer): + """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due + to dependency mapping, even if code parts with those functions/classes were overwritten)""" + + METADATA_DEPENDENCIES = (ParentNodeProvider,) + + def __init__(self, added_dependencies: set): + super().__init__() + self.top_level_functions_or_classes = {} + self.all_used_functions_or_classes = set() + self.added_dependencies = added_dependencies + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.top_level_functions_or_classes[node.name.value] = node + + def visit_ClassDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.top_level_functions_or_classes[node.name.value] = node + + def visit_Name(self, node: cst.Name): + """This is used to find any mention of a top-level function or class except its own definition. + It will contain other names as well, but those will not be used. This is the most general way to do it + since mentions may appear in a lot of different contexts (apart from simple Call to the function/class). + e.g. Attention classes are only mentionned by their name in a dict assignment. 
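To see the newly added `find_all_dependencies` in action, here is a trimmed, self-contained rerun of its docstring example; the `defaultdict(set)` stands in for the visitor-built mapping so leaf functions need no explicit entries:

```python
from collections import defaultdict, deque

def find_all_dependencies(function, dependency_mapping):
    # BFS over the immediate-dependency graph, keeping (dependency, parent) pairs in visit order
    all_dependencies = deque(dependency_mapping[function])
    all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]]
    checked = set()
    while all_dependencies:
        parent = all_dependencies.popleft()
        if parent not in checked:
            all_dependencies.extend(dependency_mapping[parent])
            all_dependencies_with_parent += [(dep, parent) for dep in dependency_mapping[parent]]
            checked.add(parent)
    return all_dependencies_with_parent

# the foo/bar/foobar example from the docstring above
dependency_mapping = defaultdict(set, {"bar": {"foo1"}, "foobar": {"bar", "foo2"}})
print(find_all_dependencies("foobar", dependency_mapping))
# e.g. [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')]  (set order may vary)
```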
+ """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + + if not ( + (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value) + or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value) + ): + self.all_used_functions_or_classes.add(node.value) + + def leave_Module(self, original_node: cst.Module, node): + # Find any class/function that was mistakenly added as part of the dependencies and remove it + unused = self.added_dependencies - self.all_used_functions_or_classes + nodes_to_remove = [ + self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes + ] + new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] + # Return a new module with the updated body + return node.with_changes(body=new_body) + + +class ModularConverterTransformer(CSTTransformer): + METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): - super().__init__(python_module) - # fmt: off - self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` + super().__init__() + self.model_name = ( + new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3` + ) self.given_old_name = given_old_name self.given_new_name = given_new_name - - self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} - self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} - - self.all_all_to_add = {} + # fmt: off + self.python_module = python_module # we store the original module to use `code_for_node` + self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module + self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"} + self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" + self.inserted_deps = [] # nodes inserted via super dependency + self.all_imports = [] # just stores all of the imports + self.all_safe_imports = [] # stores the import under simple statements + self.global_scope_index = 0 # fmt: on + self.files = { # mapping for different component bodies + "modeling": {}, + "configuration": {}, + "tokenization": {}, + "processing": {}, + "image_processing": {}, + "feature_extractor": {}, + } + self.match_patterns = "|".join(self.files.keys()) + self.all_definitions = {} + self.class_to_file_type = {} + self.current_class = None # keep track of current top-level class during visit + self.current_top_level_function = None # keep track of current top-level function during visit + # Mapping from top-level functions to classes using them + self.function_call_class_mapping = defaultdict(lambda: set()) + # Mapping from top-level functions to other top-level functions dependencies + self.function_call_dependency_mapping = defaultdict(lambda: set()) + self.added_dependencies = set() def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, - and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. 
+ """When visiting imports from `transformers.models.xxx` we need to: + 1. Get the original source code + 2. Parse it into an AST Tree + 3. Add this import to `self.transformers_imports` as visited to not parse it twice """ - import_module = self.python_module.code_for_node(node.module) - import_statement = "." * len(node.relative) + import_module - if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): - return + import_statement = self.python_module.code_for_node(node.module) if m.matches(node.module, m.Attribute()): for imported_ in node.names: - _import = re.search( - rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement - ) + _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement) if _import: - source = _import.group(1) + source = _import.groups()[0] if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): raise ValueError( f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" ) - if import_module not in self.model_specific_modules: - if "models" not in import_module: - import_module = "models." + import_module - if "transformers" not in import_module: - import_module = "transformers." + import_module - source_code = get_module_source_from_name(import_module) + if import_statement not in self.transformers_imports: + if "models" not in import_statement: + import_statement = "models." + import_statement + if "transformers" not in import_statement: + import_statement = "transformers." + import_statement + source_code = get_module_source_from_name(import_statement) tree = cst.parse_module(source_code) - self.model_specific_modules[import_module] = tree - imported_object = self.python_module.code_for_node(imported_.name) - self.model_specific_imported_objects[imported_object] = import_module + self.transformers_imports[import_statement] = tree + imported_class = self.python_module.code_for_node(imported_.name) + self.imported_mapping[imported_class] = import_statement if m.matches(node.module, m.Name()): - if "transformers" == import_module: + if "transformers" == import_statement: raise ValueError( - f"You are importing from {import_module} directly using global imports. Import from the correct local path" + f"You are importing from {import_statement} directly using global imports. Import from the correct local path" ) - def visit_SimpleStatementLine(self, node): - """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment, - simply record it or, if it is `__all__`, split it between files where we should dispatch it. - """ - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - simple_top_level_assign_structure = m.SimpleStatementLine( - body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] - ) + def leave_SimpleStatementLine(self, original_node, updated_node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) if m.matches(parent_node, m.Module()): - if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): - self.imports.append(node) - elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): - import_module = self.python_module.code_for_node(node.body[0].module) - import_statement = "." 
* len(node.body[0].relative) + import_module - if not ( - re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) - and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) - ): - self.imports.append(node) - elif m.matches(node, simple_top_level_assign_structure): - assigned_variable = node.body[0].targets[0].target.value - # __all__ is treated differently and not added to general assignments - if assigned_variable == "__all__": - self.all_all_to_add = split_all_assignment(node) - else: - self.assignments[assigned_variable] = node + if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): + if updated_node not in self.all_imports: + self.all_imports.append(updated_node) + return updated_node + elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): + full_statement = self.python_module.code_for_node(updated_node.body[0].module) + if re.search( + rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement + ): # OR MATCH ..llama.modeling_llama + return cst.RemoveFromParent() + if updated_node not in self.all_imports: + self.all_imports.append(updated_node) + return updated_node + elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): + if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): + file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value] + self.files[file_][original_node.body[0].targets[0].target.value] = { + "node": original_node, + "insert_idx": self.global_scope_index, + } + self.global_scope_index += 100 + return updated_node - def leave_Module(self, node): - """When we leave the modular file, we do the following in order: - 1. compute the nested (recursive) function and assignment dependencies - 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update - its dependency graph with the new function and assignment definitions found in the modular - 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) + def visit_ClassDef(self, node: cst.ClassDef): + """Used to keep track of current class""" + self.current_class = node.name.value + + def leave_ClassDef(self, original_node, updated_node): """ - # Takes care of finalizing our visit - super().leave_Module(node) - - # 1. compute the nested (recursive) function and assignment dependencies - self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() - - # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies - self.visited_modules = {} - self.renamers = {} - for file, module in self.model_specific_modules.items(): - file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] - renamer = ReplaceNameTransformer( - file_model_name, self.model_name, self.given_old_name, self.given_new_name - ) - renamed_module = module.visit(renamer) - self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( - renamed_module, - self.classes, - self.functions, - self.assignments, - self.object_dependency_mapping, - self.start_lines, - ) - # We record it so that we can rename classes later the exact same way - self.renamers[file] = renamer - - # 3. 
in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the - # definitions found in the visited files - self.merge_model_specific_imports(self.visited_modules) - - # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later - # Note that we may visit several of the same file types, thus we save them per file type, not file - self.imported_objects_per_file = defaultdict(set) - for file, mapper in self.visited_modules.items(): - file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) - self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) - - def merge_model_specific_imports(self, visited_modules): - """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, - based on the visited files.""" - self.start_lines_file_mapping = {} - self.added_objects_file_mapping = {} - for object_name, file in self.model_specific_imported_objects.items(): - visited_module = visited_modules[file] - self.start_lines_file_mapping[file] = visited_module.start_lines - # Add functions and their dependencies - if object_name in visited_module.functions and object_name not in self.functions: - self.functions[object_name] = visited_module.functions[object_name] - self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) - if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies - for dep in dependencies: - if dep not in self.global_nodes: - self.added_objects_file_mapping[dep] = file - self.functions[dep] = visited_module.global_nodes[dep] - - # Add assignments and their dependencies - elif object_name in visited_module.assignments and object_name not in self.assignments: - self.assignments[object_name] = visited_module.assignments[object_name] - self.added_objects_file_mapping[object_name] = file - dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) - if dependencies is not None: - self.object_recursive_dependency_mapping[object_name] = dependencies - for dep in dependencies: - if dep not in self.global_nodes: - self.added_objects_file_mapping[dep] = file - self.assignments[dep] = visited_module.global_nodes[dep] - - # Do not forget to re-assign all nodes after the merge - self.global_nodes = {**self.assignments, **self.classes, **self.functions} - - def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: - """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that - will be created based on the modular. + 1. Filter the `base` classes of this class + If they are from `transformers.models.xx` then: + - take the AST tree of the module it comes from and parse it with a `ClassFinder`. + - rename all every instance of `old_name` (llama) to `new_name` (gemma) + 2. We insert the modules which the inherited base depends on. This has to be done in + the order of the dependencies. If on is already in the new_body (because it's defined in the diff file) + then we remove it from the new body to add it again in the correct order. + 3. 
Replace the calls to `super().xxxx` merging parent code """ - relative_order = {} - idx = 0 - - original_dependencies = [] - other_files_dependencies = defaultdict(list) - for dep in tuple(missing_dependencies): - if dep in self.added_objects_file_mapping: - file = self.added_objects_file_mapping[dep] - other_files_dependencies[file].append(dep) - else: - original_dependencies.append(dep) - # Sort all lists according to the order in their respective file - all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) - for file, dependencies in other_files_dependencies.items(): - sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) - all_dependencies += sorted_dependencies - - # Add all original node first, then merged ones (one file at a time) - for dep in all_dependencies: - relative_order[dep] = idx - idx += 1 - - return relative_order - - -def check_dependencies_and_create_import_node( - file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str -) -> tuple[set[str], dict[str, cst.CSTNode]]: - """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, - we need to remove it from the dependencies, and create a new import to it instead. - This scenario may appear in the following case: - If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py` - (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as - part of the standard dependency graph (because we never encountered an import towards this new class in any file). - For example imagine the following `modular.py`: - ``` - from ..llama.modeling_llama import LlamaModel - - class NewNameTextConfig(PretrainedConfig): - ... - - class NewNameConfig(PretrainedConfig): - ... - - class NewNameModel(LlamaModel): - config = NewNameConfig() - text_config = NewNameTextConfig() - ... - ``` - then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as - `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no - knowledge of `NewNameTextConfig`. - """ - class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())} - corrected_dependencies = new_dependencies.copy() - new_imports = {} - for class_name in class_dependencies: - class_file_type = find_file_type(class_name) - # In this case, we need to remove it from the dependencies and create a new import instead - if class_file_type != file_type: - corrected_dependencies.remove(class_name) - import_statement = f"from .{class_file_type}_{new_name} import {class_name}" - new_imports[class_name] = cst.parse_statement(import_statement) - - return corrected_dependencies, new_imports - - -def get_class_node_and_dependencies( - modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] -) -> tuple[dict, str, dict]: - """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new - class node based on the inherited classes if needed. Also returns any new imports of a new class defined in - the modular that we nay need. 
- """ - bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] - if len(bases) > 1: - raise ValueError( - f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." - ) - - file_type = find_file_type(class_name) - file_to_update = files[file_type] - model_name = modular_mapper.model_name - - # This is used to avoid adding objects to the dependencies graph if they will be imported already - imported_objects = modular_mapper.imported_objects_per_file[file_type] - - # We need to replace the class node with the transformers (modeling file) super class node - if len(bases) == 1: - super_class = bases[0] - super_file_name = modular_mapper.model_specific_imported_objects[super_class] - - # Get the mapper corresponding to the inherited class - mapper = modular_mapper.visited_modules[super_file_name] - # Rename the super class according to the exact same rule we used when renaming the whole module - renamer = modular_mapper.renamers[super_file_name] - renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) - renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) - - # Create the new class node - updated_node = replace_class_node(mapper, node, renamed_super_class) - - # Grab all immediate dependencies of the new node - new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) + class_name = original_node.name.value + bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] + all_bases = [k.value.value for k in original_node.bases] + self.global_scope_index += 100 + for super_class in bases: + if super_class not in self.imported_mapping: + raise ImportError( + f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}" + ) - # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove - # it from the dependencies, and add a new import of it instead - new_node_dependencies, new_imports = check_dependencies_and_create_import_node( - file_type, new_node_dependencies, mapper, model_name - ) + super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree + model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name) + if model_name: + model_name = model_name.groups()[0] + else: + raise ValueError( + f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name" + ) + file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] + visited_module = self.visited_module + if super_file_name not in visited_module: # only extract classes once + class_finder = find_classes_in_file( + self.transformers_imports[super_file_name], + model_name, + self.model_name, + self.given_old_name, + self.given_new_name, + ) + visited_module[super_file_name] = class_finder + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } + else: # we are re-using the previously parsed data + class_finder = visited_module[super_file_name] + + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } + if len(list_dependencies) == 0: + # so, maybe standard renaming did not work (the class 
name is different) + # we try with another renaming pattern + potential_given_name = get_new_part(class_name, super_class) + del visited_module[super_file_name] + class_finder = find_classes_in_file( + self.transformers_imports[super_file_name], + model_name, + potential_given_name, + self.model_name, + potential_given_name, + ) + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } + if len(list_dependencies) == 0: + # last recourse, if the suffix of the new class is different from the one of the super class + # e.g. MyNewClassForSegmentation extends MyOldClassForObjectDetection + # we try with another renaming pattern + class_finder = find_classes_in_file( + self.transformers_imports[super_file_name], + model_name, + self.model_name, + self.given_old_name, + self.given_new_name, + super_class, + class_name, + ) + visited_module[super_file_name] = class_finder + list_dependencies = { + dep: class_finder.class_start_line.get(dep, 1000) + for dep in class_finder.class_dependency_mapping.get(class_name, []) + } + if len(list_dependencies) == 0: + raise ValueError( + f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})" + f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}." + f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`" + ) - # The node was modified -> look for all recursive dependencies of the new node - all_dependencies_to_add = find_all_dependencies( - dependency_mapping=mapper.class_dependency_mapping, - initial_dependencies=new_node_dependencies, - initial_checked_dependencies=set(file_to_update.keys()), - ) + list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) + start_insert_idx = self.global_scope_index + file_to_update = self.files[file_type] + is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" + for dependency, _ in list_dependencies: + # we can write to the correct body, using the source of the parent class + node = class_finder.global_nodes.get(dependency, None) + if node is not None: + if dependency not in file_to_update: + node = self.all_definitions.pop(dependency, node) + start_insert_idx -= 1 + file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} + self.added_dependencies.add(dependency) + elif dependency not in self.inserted_deps: + # make sure the node is written after its dependencies + start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 + if ( + dependency in file_to_update.keys() + and dependency in class_finder.first_lvl_dependency_mapping[class_name] + ): + # If dependency is defined, but not used, raise error + calls = m.findall(original_node, m.Call(func=m.Name(dependency))) + if not calls and not is_empty_node and dependency not in all_bases: + raise ValueError( + f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used + when you define `{class_name}`, as it is one of it's direct dependencies. 
Make sure + you use it in the `__init__` function.""" + ) + self.inserted_deps.append(dependency) + + if len(list_dependencies) > 0: + updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases) + + # Now, if a class was defined without parents, we look for the name + match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) + match = re.search(rf"({match_pattern})$", class_name) + if match: + key = TYPE_TO_FILE_TYPE[match.group(1)] + self.class_to_file_type[class_name] = key + self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} + else: + self.class_to_file_type[class_name] = "modeling" + self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} - relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) - nodes_to_add = { - dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add - } + self.current_class = None + return updated_node - # No transformers (modeling file) super class, just check functions and assignments dependencies - else: - updated_node = node - # The node was NOT modified -> no need to look recursively for other class dependencies. Indeed, even if they are not - # already defined (which would mean a weird order of the code in the modular...), they will be in the future - all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) - - # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove - # it from the dependencies, and add a new import of it instead - all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node( - file_type, all_dependencies_to_add, modular_mapper, model_name - ) + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_top_level_function = node.name.value - relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) - nodes_to_add = { - dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) - for dep in all_dependencies_to_add - if dep not in file_to_update.keys() - } + def leave_FunctionDef(self, original_node, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + if m.matches(parent_node, m.Module()): + self.all_definitions[node.name.value] = node + return node + + def visit_Assign(self, node: cst.Assign) -> None: + # Check if the assignment target is '__all__' + if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__": + if isinstance(node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for elt in node.value.elements: + if isinstance(elt.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = elt.value.value + file = self.class_to_file_type[ + elt.value.evaluated_value + ] # evaluated value give the content of the string + all_all_to_add[file] += [class_name] + for f_type, new_alls in all_all_to_add.items(): + updated_node = node.with_changes( + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + ) + self.files[f_type][class_name] = { + "insert_idx": self.global_scope_index + 100, + "node": updated_node, + } + + def leave_If(self, original_node, node): + parent_node = 
self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + if m.matches(parent_node, m.Module()): + full_statement = self.python_module.code_for_node(original_node.test) + if re.search(r"[\s\S]*is_.*available", full_statement): + self.all_safe_imports.append(node) + elif full_statement not in self.all_imports: + logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") + return node + + def visit_Call(self, node: cst.Call): + """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. + Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, + add calling the variable later). This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible.""" + # Only map function calls if we're inside a class (i.e., current_class is set) + if self.current_class is not None: + # Simple function calls such as foo() + if isinstance(node.func, cst.Name): + self.function_call_class_mapping[node.func.value].add(self.current_class) + elif self.current_top_level_function is not None: + # Simple function calls such as foo() + if isinstance(node.func, cst.Name): + self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) + + def _maybe_add_function_to_body( + self, + top_level_function: str, + body: dict, + function_node: cst.FunctionDef, + matching_callers: Optional[set] = None, + parent: Optional[str] = None, + ) -> bool: + """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` + is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return + `True`. Return `False` otherwise. 
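A minimal sketch of the `insert_idx` bookkeeping described above, using plain dicts in place of CST nodes (all names are hypothetical): the helper lands just before its first caller, and every entry at or after that slot is pushed down by one.

```python
# Plain-dict sketch of the reindexing: insert the helper just before its first caller,
# shifting everything at or after that position by one.
def place_before_first_caller(body: dict, name: str, callers: set) -> dict:
    new_idx = min(body[caller]["insert_idx"] for caller in callers)
    for entry in body.values():
        if entry["insert_idx"] >= new_idx:
            entry["insert_idx"] += 1
    body[name] = {"insert_idx": new_idx, "node": f"<def {name}>"}
    return body

body = {
    "Zamba2Attention": {"insert_idx": 0, "node": "..."},
    "Zamba2Model": {"insert_idx": 1, "node": "..."},
}
place_before_first_caller(body, "rotate_half", {"Zamba2Attention"})
print({k: v["insert_idx"] for k, v in body.items()})
# {'Zamba2Attention': 1, 'Zamba2Model': 2, 'rotate_half': 0}
```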
+ """ + if matching_callers is None and parent is None: + raise ValueError("Cannot add function if both the parent and the matching callers are None.") + if matching_callers is None: + matching_callers = {parent} + if len(matching_callers) > 0 and top_level_function not in body.keys(): + # Add the function just before the first class using it + new_idx = min([body[element]["insert_idx"] for element in matching_callers]) + # Reorder the elements + for element in body.keys(): + if body[element]["insert_idx"] >= new_idx: + body[element]["insert_idx"] += 1 + # Assign new element to body (after changing the count to avoid messing it) + body[top_level_function] = {"insert_idx": new_idx, "node": function_node} + return True + return False + + def _recursively_add_all_new_needed_functions_in_files(self): + """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in + the different files, and add them to the file if it is the case (also recursively adding all other functions that + may be needed in that function body).""" + # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` + for top_level_function, function_node in self.all_definitions.items(): + calling_entities = self.function_call_class_mapping[top_level_function] + # The function may be needed in different files, we need to iterate on them + for file, body in self.files.items(): + file_elements = set(body.keys()) + # If the intersection is not null, top_level_func must be added to file + matching_callers = calling_entities & file_elements + added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers) + # If the function was added, we need to recursively add all its dependencies + if added: + for dependency, parent in find_all_dependencies( + top_level_function, self.function_call_dependency_mapping + ): + self._maybe_add_function_to_body( + dependency, body, self.all_definitions[dependency], parent=parent + ) - # Add the class node itself to the nodes to add - class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 - nodes_to_add[class_name] = (class_idx, updated_node) - - return nodes_to_add, file_type, new_imports - - -def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: - """Create all the new modules based on visiting the modular file. 
It replaces all classes as necesary.""" - files = defaultdict(dict) - current_file_indices = defaultdict(lambda: 0) - - # For each class defined in modular, potentially replace the node and add it with its dependencies - for class_name, node in modular_mapper.classes.items(): - nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files) - - # Add the new potential new imports that we may need to the `modular_mapper` variable - modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys()) - modular_mapper.imports.extend(list(new_imports.values())) - - # Sort the nodes according to their relative order - nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) - # Write all nodes to file - for dependency, (_, node) in nodes_to_add: - # This is used to keep certain variables at the beginning of the file - try: - # The -1000 is arbitrary -> just keep it bigger than the list - idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) - except ValueError: - idx = current_file_indices[file_type] - current_file_indices[file_type] += 1 - files[file_type][dependency] = {"insert_idx": idx, "node": node} - - # Add the __all__ statement to files at the end - for file_type, node in modular_mapper.all_all_to_add.items(): - idx = current_file_indices[file_type] - files[file_type]["__all__"] = {"insert_idx": idx, "node": node} - - # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because - # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) - all_imports = modular_mapper.imports.copy() - all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} - for file, mapper in modular_mapper.visited_modules.items(): - new_imports = [ - node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code - ] - new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} - all_imports.extend(new_imports) - all_imports_code.update(new_imports_code) + def leave_Module(self, original_node: cst.Module, node): + imports = {self.python_module.code_for_node(k): k for k in self.all_imports} + dependency_imports = {file_type: imports.copy() for file_type in self.files} + for super_file_name, visiter in self.visited_module.items(): + file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] + dependency_imports[file_type].update( + {self.python_module.code_for_node(k): k for k in visiter.imports.values()} + ) - # Find the correct imports, and write the new modules - for file, body in files.items(): - new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - needed_imports = get_needed_imports(body, all_imports) - full_module = needed_imports + new_body - new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header) - files[file] = new_module + # Check if any new top-level function from the `modular_xxx.py` should be added to the different files + # (if it is called in a class in the file, then it will be copy pasted from `modular.py` to that file). 
+ self._recursively_add_all_new_needed_functions_in_files() - return files + for file, body in self.files.items(): + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + if len(new_body) > 0: + if file in dependency_imports.keys(): + new_body = list(dependency_imports[file].values()) + new_body + new_module = cst.Module(body=[*new_body], header=node.header) + # Final cleanup + new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) + self.files[file] = new_module + return node def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): @@ -1442,10 +1137,10 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, module = cst.parse_module(code) wrapper = MetadataWrapper(module) if cst_transformers is None: - cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) + cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, module in create_modules(cst_transformers).items(): - if module != {}: + for file, node in cst_transformers.files.items(): + if node != {}: # Get relative path starting from src/transformers/ relative_path = re.search( r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/") @@ -1454,7 +1149,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, header = AUTO_GENERATED_MESSAGE.format( relative_path=relative_path, short_name=os.path.basename(relative_path) ) - ruffed_code = run_ruff(header + module.code, True) + ruffed_code = run_ruff(header + node.code, True) formatted_code = run_ruff(ruffed_code, False) output[file] = [formatted_code, ruffed_code] return output @@ -1485,7 +1180,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/gemma/modular_gemma.py"], + default=["src/transformers/models/roberta/modular_roberta.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) @@ -1502,7 +1197,6 @@ def save_modeling_file(modular_file, converted_file): args = parser.parse_args() if args.files_to_parse == ["all"]: args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) - args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True) for file_name in find_priority_list(args.files_to_parse): print(f"Converting {file_name} to a single model single file format") diff --git a/utils/modular_model_converter.textClipping b/utils/modular_model_converter.textClipping new file mode 100644 index 0000000000000000000000000000000000000000..93a7b0661be5b2448eccefc075b9734d8e0b7aab GIT binary patch literal 259 zcmZ|EzY4-I5XbSWe=9B>L`t)dRz za9@6WQlsRF;`rvZ_PgySTqEJV-RbuFJ_}}C=MfsCL_`)dNm3W6!W?;M6v`qbaV8dw zjZ2l}k)z}C2PPkwFTNxRCb`a>Ld&WO#kej?VM$o_SCSygK|=-(6d+h&&}>m{2E4KY aN)VM${r%x+x;nVa73%Z6rZ9N*oyQYNeo5y5 literal 0 HcmV?d00001 From 4da8d5ff6dc5aed96db4d7ec6f2dbf4a3255688c Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 5 Nov 2024 19:21:29 +0000 Subject: [PATCH 08/73] added hybrid cache --- .../models/zamba2/modeling_zamba2.py | 273 ++++++++++-------- .../models/zamba2/modular_zamba2.py | 222 ++++++++++++-- 2 files changed, 352 insertions(+), 143 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py 
b/src/transformers/models/zamba2/modeling_zamba2.py index fa73fcf5905..85097b80a1f 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -60,18 +60,136 @@ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined #### added from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: - #### selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None - selective_state_update = None, None, None #### added + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added + + +class Zamba2DynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
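A quick shape sketch that mirrors the `Zamba2DynamicCache` constructor that follows, using hypothetical config values (the attribute names track `Zamba2Config`; the numbers are illustrative only):

```python
# Hypothetical config values; per-layer cache shapes follow the constructor below.
import torch

batch_size = 2
hidden_size, mamba_expand = 2560, 2
mamba_ngroups, mamba_d_state, mamba_d_conv, mamba_headdim = 1, 64, 4, 64

intermediate_size = int(mamba_expand * hidden_size)                # d_inner
conv_dim = intermediate_size + 2 * mamba_ngroups * mamba_d_state   # conv channels
num_heads = intermediate_size // mamba_headdim

conv_state = torch.zeros(batch_size, conv_dim, mamba_d_conv)       # one per mamba layer
ssm_state = torch.zeros(batch_size, num_heads, mamba_headdim, mamba_d_state)
print(conv_state.shape)  # torch.Size([2, 5248, 4])
print(ssm_state.shape)   # torch.Size([2, 80, 64, 64])
```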
+ """ + + def __init__( + self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + ): + self.layers_block_type = config.layers_block_type + self.transformer_layers = [] + + self.has_previous_state = False + self.dtype = dtype + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * config.hidden_size) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + num_heads = self.intermediate_size // config.mamba_headdim + self.ssm_states = { + i: torch.zeros( + batch_size, num_heads, config.mamba_headdim, config.mamba_d_state, device=device, dtype=dtype + ) + for i in range(config.num_hidden_layers) + } + + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.update + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.reorder_cache + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.get_seq_length + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.to_legacy_cache + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + + @classmethod + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.from_legacy_cache + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + #### is_fast_path_available = all( #### (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) #### ) -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added logger = logging.get_logger(__name__) @@ -157,92 +275,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
- """ - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers @@ -376,7 +408,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -418,7 +450,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - layer_idx = self.layer_block_map[layer_idx] key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) # repeat k/v heads if n_kv_heads < n_heads @@ -478,7 +509,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -522,7 +553,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - layer_idx = self.layer_block_map[layer_idx] key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) # repeat k/v heads if n_kv_heads < n_heads @@ -593,7 +623,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -814,7 +844,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_params: Optional[Zamba2DynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -825,7 +855,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -909,6 +939,11 @@ def 
cuda_kernels_forward( ) # 1D Convolution + hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2) + conv_state = nn.functional.pad( + hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -953,7 +988,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -967,7 +1002,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio if cache_params is not None: ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = ssm_state.to(hidden_states.device) - if cache_params.seqlen_offset > 0: + if cache_params.has_previous_state: conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) # handle batched generation - states are copied through @@ -997,7 +1032,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -1098,7 +1133,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
else: previous_states = torch.zeros_like(states[:, :1]) @@ -1143,11 +1178,11 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio def forward( self, hidden_states, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_params: Optional[Zamba2DynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): - if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + if not is_fast_path_available and "cuda" in self.in_proj.weight.device.type: return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) dtype = hidden_states.dtype if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: @@ -1246,7 +1281,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1261,7 +1296,7 @@ def forward( (see fig. 2 in https://arxiv.org/pdf/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1314,7 +1349,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1325,7 +1360,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1382,7 +1417,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1395,7 +1430,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1471,7 +1506,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = False - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _supports_cache_class = True # Note: only supports Zamba2DynamicCache _is_stateful = True def _init_weights(self, module): @@ -1532,7 +1567,7 @@ def _check_and_enable_flash_attn_2( Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` as we do not want to disable Flash Attention 2 in Zamba2. """ - return super(ZambaPreTrainedModel)._check_and_enable_flash_attn_2( + return PreTrainedModel._check_and_enable_flash_attn_2( config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map ) @@ -1577,14 +1612,14 @@ def _check_and_enable_flash_attn_2( config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `HybridMambaAttentionDynamicCache` class for more details. + See the `Zamba2DynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1684,7 +1719,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[Zamba2DynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1721,7 +1756,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was " "provided, so no cache will be returned." 
) @@ -1867,7 +1902,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[Zamba2DynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1961,7 +1996,7 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + # Overwitten -- has a unique cache type, `Zamba2DynamicCache` empty_past_kv = past_key_values is None @@ -1975,9 +2010,7 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = HybridMambaAttentionDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) + past_key_values = Zamba2DynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 5398fa2d7b4..b43762f206f 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -54,7 +54,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update = None, None, None + selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -62,6 +62,7 @@ causal_conv1d_update, causal_conv1d_fn = None, None is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + from ..zamba.modeling_zamba import ( ZambaMambaDecoderLayer, ZambaAttention, @@ -70,7 +71,6 @@ ZambaModel, ZambaPreTrainedModel, ZambaRMSNorm, - HybridMambaAttentionDynamicCache, repeat_kv, ) from ...configuration_utils import PretrainedConfig @@ -381,6 +381,123 @@ def segment_sum(input_tensor): return tensor_segsum +class Zamba2DynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. 
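    For orientation, here is a minimal self-contained sketch of the per-layer tensors this cache keeps, mirroring the shapes built in `__init__` below and the rolling update used for single-token decoding; the config values are illustrative placeholders, not defaults of any released checkpoint.

        import torch

        # Illustrative config values (placeholders).
        batch_size = 2
        hidden_size, mamba_expand = 2560, 2
        mamba_ngroups, mamba_d_state = 8, 64
        mamba_headdim, mamba_d_conv = 64, 4

        intermediate_size = int(mamba_expand * hidden_size)               # 5120
        conv_dim = intermediate_size + 2 * mamba_ngroups * mamba_d_state  # 6144
        num_heads = intermediate_size // mamba_headdim                    # 80

        # Per mamba layer: convolution and SSM states keep a constant shape, independent of seq_len.
        conv_state = torch.zeros(batch_size, conv_dim, mamba_d_conv)                  # (2, 6144, 4)
        ssm_state = torch.zeros(batch_size, num_heads, mamba_headdim, mamba_d_state)  # (2, 80, 64, 64)

        # Single-token decoding rolls the conv state left and writes the newest column,
        # in the spirit of update_conv_state / torch_forward in this patch.
        new_column = torch.randn(batch_size, conv_dim)
        conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
        conv_state[:, :, -1] = new_column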
+ """ + + def __init__( + self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None + ): + self.layers_block_type = config.layers_block_type + self.transformer_layers = [] + + self.has_previous_state = False + self.dtype = dtype + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = int(config.mamba_expand * config.hidden_size) + + self.conv_states = { + i: torch.zeros( + batch_size, + self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, + self.conv_kernel_size, + device=device, + dtype=dtype, + ) + for i in range(config.num_hidden_layers) + } + num_heads = self.intermediate_size // config.mamba_headdim + self.ssm_states = { + i: torch.zeros( + batch_size, num_heads, config.mamba_headdim, config.mamba_d_state, device=device, dtype=dtype + ) + for i in range(config.num_hidden_layers) + } + + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) + + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] + + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.update + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.reorder_cache + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.get_seq_length + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.to_legacy_cache + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + + @classmethod + # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.from_legacy_cache + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + + # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba class Zamba2RMSNorm(ZambaRMSNorm): pass @@ -536,7 +653,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -578,7 +695,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - layer_idx = self.layer_block_map[layer_idx] key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) # repeat k/v heads if n_kv_heads < n_heads @@ -638,7 +754,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -682,7 +798,6 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - layer_idx = self.layer_block_map[layer_idx] key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) # repeat k/v heads if n_kv_heads < n_heads @@ -753,7 +868,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -925,7 +1040,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_params: Optional[Zamba2DynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -936,7 +1051,7 @@ def cuda_kernels_forward( d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads # getting projected states from cache if it exists - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 
2D) d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2 split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads] @@ -1020,6 +1135,12 @@ def cuda_kernels_forward( ) # 1D Convolution + hidden_states_B_C_t = hidden_states_B_C.transpose(1,2) + conv_state = nn.functional.pad( + hidden_states_B_C_t, + (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] @@ -1064,7 +1185,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -1078,7 +1199,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio if cache_params is not None: ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = ssm_state.to(hidden_states.device) - if cache_params.seqlen_offset > 0: + if cache_params.has_previous_state: conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) # handle batched generation - states are copied through @@ -1108,7 +1229,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2)) hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1) A = -torch.exp(self.A_log.float()) # [num_heads] - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: # Note: there is no need to pad parameter matrices here, as there is just one new token # for batched generation dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...] @@ -1209,7 +1330,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None] # permute back B * decay states states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3) - if cache_params is not None and cache_params.seqlen_offset > 0: + if cache_params is not None and cache_params.has_previous_state: previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...] 
else: previous_states = torch.zeros_like(states[:, :1]) @@ -1254,7 +1375,7 @@ def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentio def forward( self, hidden_states, - cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_params: Optional[Zamba2DynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -1335,7 +1456,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1350,7 +1471,7 @@ def forward( (see fig. 2 in https://arxiv.org/pdf/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1410,7 +1531,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1423,7 +1544,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1522,7 +1643,6 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True - @classmethod @classmethod def _check_and_enable_flash_attn_2( cls, @@ -1536,7 +1656,7 @@ def _check_and_enable_flash_attn_2( Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` as we do not want to disable Flash Attention 2 in Zamba2. """ - return super(ZambaPreTrainedModel)._check_and_enable_flash_attn_2( + return PreTrainedModel._check_and_enable_flash_attn_2( config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map ) @@ -1576,14 +1696,14 @@ def _check_and_enable_flash_attn_2( config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) - past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `HybridMambaAttentionDynamicCache` class for more details. + See the `Zamba2DynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1681,6 +1801,62 @@ def __init__(self, config: Zamba2Config): # Initialize weights and apply final processing self.post_init() + + # Adapted from transformers.models.zamba.modeling_zamba.ZambaForCausalLM.prepare_inputs_for_generation + # with `Zamba2DynamicCache` -> `Zamba2DynamicCache` + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Overwitten -- has a unique cache type, `Zamba2DynamicCache` + + empty_past_kv = past_key_values is None + + # Omit tokens covered by past_key_values + if not empty_past_kv: + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = Zamba2DynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "num_logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs @add_start_docstrings( From 6b5a9be2647a413e758b63131cd6ecf2f1b387e2 Mon Sep 17 00:00:00 2001 From: pglorio Date: 
Tue, 5 Nov 2024 20:30:46 +0000 Subject: [PATCH 09/73] fix attention_mask in mamba --- .../models/zamba2/modeling_zamba2.py | 31 ++++++------------- .../models/zamba2/modular_zamba2.py | 12 ++++--- 2 files changed, 17 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 85097b80a1f..e58bc66707e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -900,7 +900,7 @@ def cuda_kernels_forward( out = self.out_proj(hidden_states)[:, None, ...] # if no cache is found, calling the kernel else: - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask == 1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -908,8 +908,12 @@ def cuda_kernels_forward( projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + if attention_mask is not None: + input_not_masked = torch.all(attention_mask == 1) + else: + input_not_masked = True - if self.training and cache_params is None: + if self.training and cache_params is None and input_not_masked: out, ssm_state = mamba_split_conv1d_scan_combined( projected_states, self.conv1d.weight.squeeze(1), @@ -960,7 +964,7 @@ def cuda_kernels_forward( [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask == 1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1020,7 +1024,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache] ) cache_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask==1): dtype = hidden_states.dtype # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1182,7 +1186,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): - if not is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) dtype = hidden_states.dtype if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: @@ -1549,21 +1553,6 @@ def _check_and_enable_flash_attn_2( Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. 
- Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` - as we do not want to disable Flash Attention 2 in Zamba2. - """ - """ - Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. - Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. - """ - config = super()._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "flash_attention_2": - config._attn_implementation = "eager" - """ Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` as we do not want to disable Flash Attention 2 in Zamba2. """ @@ -1571,8 +1560,6 @@ def _check_and_enable_flash_attn_2( config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map ) - return config - _CONFIG_FOR_DOC = "Zamba2Config" diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b43762f206f..0028d8ce6c1 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1096,7 +1096,7 @@ def cuda_kernels_forward( out = self.out_proj(hidden_states)[:, None, ...] # if no cache is found, calling the kernel else: - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask==1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1104,8 +1104,12 @@ def cuda_kernels_forward( projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + if attention_mask is not None: + input_not_masked = torch.all(attention_mask==1) + else: + input_not_masked = True - if self.training and cache_params is None: + if self.training and cache_params is None and input_not_masked: out, ssm_state = mamba_split_conv1d_scan_combined( projected_states, self.conv1d.weight.squeeze(1), @@ -1157,7 +1161,7 @@ def cuda_kernels_forward( [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask==1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1217,7 +1221,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache] ) cache_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len] - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + if attention_mask is not None and not torch.all(attention_mask==1): dtype = hidden_states.dtype # tune out hidden states for pad tokens, see 
https://github.com/state-spaces/mamba/issues/66 hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) From 248350d6331f7e1990406e37c05a99449253627f Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 5 Nov 2024 22:42:17 +0000 Subject: [PATCH 10/73] dropped unused loras --- .../models/zamba2/modeling_zamba2.py | 93 ++++++++++++------- .../models/zamba2/modular_zamba2.py | 80 +++++++++------- 2 files changed, 105 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index e58bc66707e..da19ec04a26 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -187,9 +187,9 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -#### is_fast_path_available = all( -#### (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -#### ) +# is_fast_path_available = all( +# (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +# ) logger = logging.get_logger(__name__) @@ -336,9 +336,23 @@ class Zamba2Attention(nn.Module): Additionally, replaced attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + + Multi-headed attention from 'Attention Is All You Need' paper. + + Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: + The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. + The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + Additionally, replaced + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this + layer is tied, un-tied LoRA modules are added to the q, k, v projectors to increase expressivity with a small memory overhead. 
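    As a rough illustration of the tied-block plus un-tied LoRA idea described here, the sketch below combines a shared projection with a low-rank correction that is only instantiated for the block invocations this memory block serves. The way the delta is added, `proj(x) + B(A(x))`, is the standard LoRA form and is an assumption for illustration; the names and sizes are placeholders rather than values taken from this patch.

        import torch
        from torch import nn

        attention_hidden_size, lora_rank = 5120, 128
        num_mem_blocks, num_fwd_mem_blocks, block_id = 2, 6, 0

        q_proj = nn.Linear(attention_hidden_size, attention_hidden_size, bias=False)  # tied across blocks

        # Real (A, B) pairs only where i % num_mem_blocks == block_id; other slots are identity placeholders.
        lora_A = nn.ModuleList([
            nn.Linear(attention_hidden_size, lora_rank, bias=False) if i % num_mem_blocks == block_id else nn.Identity()
            for i in range(num_fwd_mem_blocks)
        ])
        lora_B = nn.ModuleList([
            nn.Linear(lora_rank, attention_hidden_size, bias=False) if i % num_mem_blocks == block_id else nn.Identity()
            for i in range(num_fwd_mem_blocks)
        ])

        x = torch.randn(1, 7, attention_hidden_size)
        fwd_idx = 2  # which invocation of the shared block we are in (served by block_id 0 here)
        q = q_proj(x) + lora_B[fwd_idx](lora_A[fwd_idx](x))  # low-rank, invocation-specific correction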
""" - def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_mem_blocks=None): + def __init__( + self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks=None, block_id: int = None + ): super().__init__() self.config = config self.layer_idx = layer_idx @@ -361,33 +375,38 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_me self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.num_mem_blocks = num_mem_blocks + self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta self.layer_block_map = layer_type_list(config) + self.block_id = block_id - ### add to config: - # config.attention_hidden_size - # config.attention_head_dim - # config.max_position_embeddings if config.use_shared_attention_lora: - self.linear_q_lora_A_list = nn.ParameterList([]) - self.linear_q_lora_B_list = nn.ParameterList([]) - self.linear_k_lora_A_list = nn.ParameterList([]) - self.linear_k_lora_B_list = nn.ParameterList([]) - self.linear_v_lora_A_list = nn.ParameterList([]) - self.linear_v_lora_B_list = nn.ParameterList([]) - - for i in range(self.num_mem_blocks): - linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + self.linear_q_lora_A_list = nn.ModuleList([]) + self.linear_q_lora_B_list = nn.ModuleList([]) + self.linear_k_lora_A_list = nn.ModuleList([]) + self.linear_k_lora_B_list = nn.ModuleList([]) + self.linear_v_lora_A_list = nn.ModuleList([]) + self.linear_v_lora_B_list = nn.ModuleList([]) + + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + else: + linear_q_lora_A = nn.Identity() + linear_q_lora_B = nn.Identity() + linear_k_lora_A = nn.Identity() + linear_k_lora_B = nn.Identity() + linear_v_lora_A = nn.Identity() + linear_v_lora_B = nn.Identity() self.linear_q_lora_A_list.append(linear_q_lora_A) self.linear_q_lora_B_list.append(linear_q_lora_B) - linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) self.linear_k_lora_A_list.append(linear_k_lora_A) self.linear_k_lora_B_list.append(linear_k_lora_B) - linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) self.linear_v_lora_A_list.append(linear_v_lora_A) self.linear_v_lora_B_list.append(linear_v_lora_B) @@ -1204,15 +1223,16 @@ def forward( class Zamba2MLP(nn.Module): - def __init__(self, config: Zamba2Config, num_mem_blocks=None): + def 
__init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ - Shared MLP layer. To the intermediate activations of the MLP, we add un-shared LoRA's, which - introduce some amount of diversification across the shared MLP layers. + This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer + is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__() self.config = config self.hidden_size = config.hidden_size - self.num_mem_blocks = num_mem_blocks + self.num_fwd_mem_blocks = num_fwd_mem_blocks + self.block_id = block_id self.ffn_intermediate_size = config.intermediate_size self.act_fn = ACT2FN[config.hidden_act] @@ -1230,9 +1250,13 @@ def gated_act_fn(x): if self.config.use_shared_block_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) self.gate_up_proj_lora_B_list = nn.ModuleList([]) - for i in range(self.num_mem_blocks): - gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) + else: + gate_up_proj_lora_A = nn.Identity() + gate_up_proj_lora_B = nn.Identity() self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) @@ -1268,13 +1292,14 @@ def count_mem_blocks_in_config(config: Zamba2Config): class Zamba2AttentionDecoderLayer(nn.Module): - def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() num_gs = count_mem_blocks_in_config(config) + self.block_id = block_id self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( - config, layer_idx=-1, num_mem_blocks=num_gs + config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) - self.feed_forward = Zamba2MLP(config, num_mem_blocks=num_gs) + self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1651,7 +1676,7 @@ def __init__(self, config: Zamba2Config): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] + blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)] mamba_layers = [] linear_layers = [] self.layers_block_type = config.layers_block_type diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 0028d8ce6c1..9651e39ce30 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -593,8 +593,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class Zamba2Attention(ZambaAttention): """ - Multi-headed attention from 'Attention Is All You Need' paper. 
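    A tiny sketch of the modified softmax scale mentioned in these attention docstrings: because the attention input concatenates two hidden states, head_dim is twice the usual size and the scores are divided by sqrt(head_dim / 2). The shapes below are illustrative only.

        import math
        import torch

        batch, num_heads, seq_len = 1, 32, 7
        hidden_size = 2560
        head_dim = 2 * hidden_size // num_heads  # doubled input width -> doubled head_dim (160 here)

        query_states = torch.randn(batch, num_heads, seq_len, head_dim)
        key_states = torch.randn(batch, num_heads, seq_len, head_dim)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim / 2)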
Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". + Multi-headed attention from 'Attention Is All You Need' paper. Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. @@ -603,37 +602,44 @@ class Zamba2Attention(ZambaAttention): Additionally, replaced attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) + Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this + layer is tied, un-tied LoRA modules are added to the q, k, v projectors to increase expressivity with a small memory overhead. """ - def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_mem_blocks = None): + def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks = None, block_id: int = None): super().__init__(config, layer_idx) - self.num_mem_blocks = num_mem_blocks + self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta self.layer_block_map = layer_type_list(config) - - ### add to config: - # config.attention_hidden_size - # config.attention_head_dim - # config.max_position_embeddings + self.block_id = block_id + if config.use_shared_attention_lora: - self.linear_q_lora_A_list = nn.ParameterList([]) - self.linear_q_lora_B_list = nn.ParameterList([]) - self.linear_k_lora_A_list = nn.ParameterList([]) - self.linear_k_lora_B_list = nn.ParameterList([]) - self.linear_v_lora_A_list = nn.ParameterList([]) - self.linear_v_lora_B_list = nn.ParameterList([]) + self.linear_q_lora_A_list = nn.ModuleList([]) + self.linear_q_lora_B_list = nn.ModuleList([]) + self.linear_k_lora_A_list = nn.ModuleList([]) + self.linear_k_lora_B_list = nn.ModuleList([]) + self.linear_v_lora_A_list = nn.ModuleList([]) + self.linear_v_lora_B_list = nn.ModuleList([]) - for i in range(self.num_mem_blocks): - linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + else: + linear_q_lora_A = nn.Identity() + linear_q_lora_B = nn.Identity() + linear_k_lora_A = nn.Identity() + linear_k_lora_B = nn.Identity() + linear_v_lora_A = nn.Identity() + linear_v_lora_B = nn.Identity() self.linear_q_lora_A_list.append(linear_q_lora_A) self.linear_q_lora_B_list.append(linear_q_lora_B) - linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) 
self.linear_k_lora_A_list.append(linear_k_lora_A) self.linear_k_lora_B_list.append(linear_k_lora_B) - linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) self.linear_v_lora_A_list.append(linear_v_lora_A) self.linear_v_lora_B_list.append(linear_v_lora_B) @@ -1394,15 +1400,16 @@ def forward( class Zamba2MLP(nn.Module): - def __init__(self, config: Zamba2Config, num_mem_blocks = None): + def __init__(self, config: Zamba2Config, num_fwd_mem_blocks = None, block_id: int = None): """ - Shared MLP layer. To the intermediate activations of the MLP, we add un-shared LoRA's, which - introduce some amount of diversification across the shared MLP layers. + This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer + is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__() self.config = config self.hidden_size = config.hidden_size - self.num_mem_blocks = num_mem_blocks + self.num_fwd_mem_blocks = num_fwd_mem_blocks + self.block_id = block_id self.ffn_intermediate_size = config.intermediate_size self.act_fn = ACT2FN[config.hidden_act] @@ -1411,16 +1418,20 @@ def gated_act_fn(x): return self.act_fn(x[0]) * x[1] self.gated_act_fn = gated_act_fn - + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_block_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) self.gate_up_proj_lora_B_list = nn.ModuleList([]) - for i in range(self.num_mem_blocks): - gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + else: + gate_up_proj_lora_A = nn.Identity() + gate_up_proj_lora_B = nn.Identity() self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) @@ -1445,11 +1456,12 @@ def forward(self, hidden_state, layer_idx = None): class Zamba2AttentionDecoderLayer(nn.Module): - def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None): + def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() num_gs = count_mem_blocks_in_config(config) - self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_mem_blocks = num_gs) - self.feed_forward = Zamba2MLP(config, num_mem_blocks = num_gs) + self.block_id = block_id + self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) + self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1752,7 +1764,7 @@ def __init__(self, config: 
Zamba2Config): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - blocks = [Zamba2AttentionDecoderLayer(config) for _ in range(config.num_mem_blocks)] + blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)] mamba_layers = [] linear_layers = [] self.layers_block_type = config.layers_block_type From d1d2c668f1086fb377c252d59be35abddff90263 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 5 Nov 2024 23:08:40 +0000 Subject: [PATCH 11/73] fix flash2 --- .../models/zamba2/modeling_zamba2.py | 28 ++--------------- .../models/zamba2/modular_zamba2.py | 31 +++++++------------ 2 files changed, 13 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index da19ec04a26..0d980e6203d 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -187,11 +187,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -# is_fast_path_available = all( -# (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -# ) - - logger = logging.get_logger(__name__) @@ -521,6 +516,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.is_causal = True def forward( self, @@ -614,6 +610,7 @@ def forward( value_states, attention_mask, q_len, + is_causal = self.is_causal, dropout=dropout_rate, softmax_scale=softmax_scale, ) @@ -1564,27 +1561,6 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True - @classmethod - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, - hard_check_only: bool = False, - check_device_map: bool = False, - ): - """ - Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. - Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. - - Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` - as we do not want to disable Flash Attention 2 in Zamba2. 
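    Since these patches keep `_supports_flash_attn_2 = True` and drop the override that would have disabled it, a user-facing load call would presumably look like the sketch below; the checkpoint id is a placeholder (not taken from this patch) and `attn_implementation` is the standard `from_pretrained` switch.

        import torch
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            "Zyphra/Zamba2-2.7B",                     # placeholder checkpoint id
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",  # requires flash-attn to be installed
        )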
- """ - return PreTrainedModel._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - _CONFIG_FOR_DOC = "Zamba2Config" diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 9651e39ce30..ecd78a7b071 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -612,6 +612,7 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fw self.rope_theta = config.rope_theta self.layer_block_map = layer_type_list(config) self.block_id = block_id + self.is_causal = True if config.use_shared_attention_lora: self.linear_q_lora_A_list = nn.ModuleList([]) @@ -848,6 +849,7 @@ def forward( q_len, dropout=dropout_rate, softmax_scale=softmax_scale, + is_causal=self.is_causal, ) attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() @@ -1627,10 +1629,16 @@ def forward( "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", ZAMBA2_START_DOCSTRING, ) -class Zamba2PreTrainedModel(ZambaPreTrainedModel): +class Zamba2PreTrainedModel(PreTrainedModel): + config_class = Zamba2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - # Leaving this commented out for now until testing - # _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _supports_sdpa = False + _supports_cache_class = True # Note: only supports Zamba2DynamicCache + _is_stateful = True def _init_weights(self, module): std = self.config.initializer_range @@ -1659,23 +1667,6 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, - hard_check_only: bool = False, - check_device_map: bool = False, - ): - """ - Replaces `ZambaPreTrainedModel._check_and_enable_flash_attn_2` with `PreTrainedModel._check_and_enable_flash_attn_2` - as we do not want to disable Flash Attention 2 in Zamba2. - """ - return PreTrainedModel._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - ZAMBA2_INPUTS_DOCSTRING = r""" Args: From 5f5d01ea78714a4b88ab49ac5241c98e482b4af3 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 6 Nov 2024 08:51:20 +0000 Subject: [PATCH 12/73] config docstrings --- .../models/zamba2/modular_zamba2.py | 99 ++++++++++--------- 1 file changed, 52 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index ecd78a7b071..e1678ea0740 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -95,15 +95,48 @@ class Zamba2Config(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. 
tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. - ffn_hidden_size (`int`, *optional*, defaults to 4 * hidden_size): - Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 54): Number of hidden layers in the model. + mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. + mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + mamba_headdim (`int`, *optional*, defaults to 64): + Dimension of each mamba head. + mamba_ngroups (`int`, *optional*, defaults to 8): + Number of groups for the evolution matrices of mamba 2. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + Accepted range of time step values. + mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + n_mamba_heads (`int`, *optional*, defaults to 1): + Number of heads for the evolution matrices of mamba 2. + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) adjacent to the mamba conv. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + add_bias_linear (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in various layers + intermediate_size (`int`, *optional*, defaults to 4 * hidden_size): + Dimension of the MLP representations. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the MLP. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. num_key_value_heads (`int`, *optional*): @@ -113,11 +146,23 @@ class Zamba2Config(PretrainedConfig): converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). - mamba_headdim (``, *optional*, defaults to 64): - dimension of each Mamba2 heads (number of heads is set to 1 in this implementation). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_mem_blocks (`int`, *optional*, defaults to 1): + Number of unshared transformer blocks. 
+ use_shared_block_lora (`bool`, *optional*, defaults to `True`): + If True, unshared LoRA's will be added to the shared MLP's. + use_shared_attention_lora (`bool`, *optional*, defaults to `False`): + If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. + lora_rank (`int`, *optional*, defaults to 128): + Rank of the LoRA in the shared MLP and shared attention layers. + use_mem_rope (`bool`, *optional*, defaults to `False`): + If True, includes RoPE in the shared attention layers. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): + rms_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -134,29 +179,6 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. - sliding_window (`int`, *optional*): - Sliding window attention window size. If not specified, will default to `None`. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_mamba_kernels (`bool`, *optional*, defaults to `True`): - Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and - `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. 
Raises ValueError if - `True` and kernels are not available - state_size (`int`, *optional*, defaults to 16): - The dimension the mamba state space latents - mamba_d_conv (`int`, *optional*, defaults to 4): - The size of the mamba convolution kernel - mamba_expand (`int`, *optional*, defaults to 2): - Expanding factor (relative to hidden_size) used to determine the mamba intermediate size - add_bias_linear (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to use bias in various layers - gated_linear_units (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to use gated MLP - use_shared_block_lora (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP - inside the shared transformer blocks - state_size (`int`, *optional*, defaults to 128): - The rank of the LoRA modules inside the MLP of the shared transformer blocks """ model_type = "zamba2" @@ -182,31 +204,24 @@ def __init__( mamba_dt_rank="auto", n_mamba_heads=1, - mamba_conv_bias=True, mamba_proj_bias=False, hidden_mamba_act="silu", - use_mamba_kernels=True, use_conv_bias=True, chunk_size=256, add_bias_linear=False, intermediate_size=None, - gated_linear_unit=True, hidden_act="gelu", num_attention_heads=32, num_key_value_heads=None, - sliding_window=None, attention_dropout=0.0, num_mem_blocks=1, use_shared_block_lora=True, use_shared_attention_lora=False, lora_rank=128, - use_mem_eff_path=True, use_mem_rope=False, rope_theta=10000, - attention_hidden_size=None, - attention_head_dim=None, initializer_range=0.02, rms_norm_eps=1e-5, @@ -238,18 +253,11 @@ def __init__( self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window self.num_mem_blocks = num_mem_blocks self.use_mem_rope = use_mem_rope self.rope_theta = rope_theta - if attention_hidden_size is None: - self.attention_hidden_size = 2 * hidden_size - else: - self.attention_hidden_size = attention_hidden_size - if attention_head_dim is None: - self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads - else: - self.attention_head_dim = attention_head_dim + self.attention_hidden_size = 2 * hidden_size + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads self.attention_dropout = attention_dropout self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv @@ -262,12 +270,10 @@ def __init__( self.mamba_conv_bias = mamba_conv_bias self.mamba_proj_bias = mamba_proj_bias self.hidden_mamba_act = hidden_mamba_act - self.use_mamba_kernels = use_mamba_kernels self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.gated_linear_unit = gated_linear_unit self.use_shared_block_lora = use_shared_block_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -294,7 +300,6 @@ def __init__( self.use_cache = use_cache self.num_logits_to_keep = num_logits_to_keep - self.use_mem_eff_path = use_mem_eff_path # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) From c1b7647fc0deab5e5ab848f5bc6e715edee16c6d Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 7 Nov 2024 03:29:41 +0000 Subject: [PATCH 13/73] fix config and fwd pass --- .../models/zamba2/configuration_zamba2.py | 24 +++++++----- .../models/zamba2/modeling_zamba2.py | 13 ++++--- 
.../models/zamba2/modular_zamba2.py | 38 +++++++++++++------ 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index f2457ac400f..2502d7c4de5 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -93,7 +93,7 @@ class Zamba2Config(PretrainedConfig): Expanding factor (relative to hidden_size) used to determine the mamba intermediate size add_bias_linear (`bool`, *optional*, defaults to `False`): Flag indicating whether or not to use bias in various layers - gated_linear_units (`bool`, *optional*, defaults to `False`): + gated_linear_unit (`bool`, *optional*, defaults to `True`): Flag indicating whether or not to use gated MLP use_shared_block_lora (`bool`, *optional*, defaults to `False`): Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP @@ -112,6 +112,7 @@ def __init__( tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, + layers_block_type=None, mamba_d_state=64, mamba_d_conv=4, mamba_expand=2, @@ -234,12 +235,15 @@ def __init__( self.use_mem_eff_path = use_mem_eff_path # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) - self.layers_block_type = ( - ["mamba"] - + (["mamba"] * 5 + ["hybrid"]) * 7 - + ["mamba"] * 4 - + ["hybrid"] - + ["mamba"] * 3 - + ["hybrid"] - + ["mamba"] * 2 - ) + if layers_block_type is None: + self.layers_block_type = ( + ["mamba"] + + (["mamba"] * 5 + ["hybrid"]) * 7 + + ["mamba"] * 4 + + ["hybrid"] + + ["mamba"] * 3 + + ["hybrid"] + + ["mamba"] * 2 + ) + else: + self.layers_block_type = layers_block_type diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 0d980e6203d..113cafaf121 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -431,6 +431,7 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: + layer_idx = self.layer_dic[layer_idx] lora_layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] @@ -533,6 +534,7 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: + layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] q_lora_output = linear_q_lora_A(hidden_states) @@ -959,11 +961,12 @@ def cuda_kernels_forward( ) # 1D Convolution - hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2) - conv_state = nn.functional.pad( - hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) - ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + if cache_params is not None: + hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2) + conv_state = nn.functional.pad( + hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py 
index e1678ea0740..cb6ea04a312 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -104,6 +104,8 @@ class Zamba2Config(PretrainedConfig): Dimension of the hidden representations. num_hidden_layers (`int`, *optional*, defaults to 54): Number of hidden layers in the model. + layers_block_type (`list`, *optional*): + List of layer types, which can be either "mamba" or "hybrid". mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. @@ -150,7 +152,7 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_block_lora (`bool`, *optional*, defaults to `True`): + use_shared_block_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the shared MLP's. use_shared_attention_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. @@ -191,6 +193,7 @@ def __init__( tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, + layers_block_type=None, mamba_d_state=64, mamba_d_conv=4, @@ -217,7 +220,7 @@ def __init__( attention_dropout=0.0, num_mem_blocks=1, - use_shared_block_lora=True, + use_shared_block_lora=False, use_shared_attention_lora=False, lora_rank=128, use_mem_rope=False, @@ -295,15 +298,23 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps - if intermediate_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - self.use_cache = use_cache self.num_logits_to_keep = num_logits_to_keep # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) - self.layers_block_type = ['mamba'] + (['mamba'] * 5 + ['hybrid']) * 7 + ['mamba'] * 4 + ['hybrid'] + ['mamba'] * 3 + ['hybrid'] + ['mamba'] * 2 + if layers_block_type is None: + self.layers_block_type = ( + ["mamba"] + + (["mamba"] * 5 + ["hybrid"]) * 7 + + ["mamba"] * 4 + + ["hybrid"] + + ["mamba"] * 3 + + ["hybrid"] + + ["mamba"] * 2 + ) + else: + self.layers_block_type = layers_block_type def count_mem_blocks_in_config(config: Zamba2Config): @@ -674,6 +685,7 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: + layer_idx = self.layer_dic[layer_idx] lora_layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] @@ -775,6 +787,7 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: + layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] q_lora_output = linear_q_lora_A(hidden_states) @@ -1152,12 +1165,13 @@ def cuda_kernels_forward( ) # 1D Convolution - hidden_states_B_C_t = hidden_states_B_C.transpose(1,2) - conv_state = nn.functional.pad( - hidden_states_B_C_t, - (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) - ) - cache_params.conv_states[self.layer_idx].copy_(conv_state) + if cache_params is not None: + hidden_states_B_C_t = hidden_states_B_C.transpose(1,2) + conv_state = nn.functional.pad( + 
hidden_states_B_C_t, + (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_state) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: hidden_states_B_C = self.act( self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len] From 979b99bf44da4d671240b2a91b2b2a4cccdd5554 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 7 Nov 2024 04:27:59 +0000 Subject: [PATCH 14/73] make fixup fixes --- .../models/mamba2/modeling_mamba2.py | 2 +- .../models/zamba2/modeling_zamba2.py | 21 +-- .../models/zamba2/modular_zamba2.py | 173 ++++++++---------- 3 files changed, 85 insertions(+), 111 deletions(-) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 8661495dbf6..c312b9b9435 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -44,7 +44,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + selective_state_update = None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 113cafaf121..48256cbe4bb 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -57,8 +57,11 @@ if is_mamba_ssm_available(): #### from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined #### added from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import ( #### added + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + ) else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None @@ -66,8 +69,8 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added class Zamba2DynamicCache(DynamicCache): @@ -136,7 +139,7 @@ def reset(self): self.conv_states.zero_() self.ssm_states.zero_() - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.update + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, key_states: torch.Tensor, @@ -154,7 +157,7 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.reorder_cache + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): @@ -168,7 +171,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, 
beam_idx.to(device)) - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.get_seq_length + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # take any layer that contains cache and not empty tensor @@ -177,12 +180,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.to_legacy_cache def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") @classmethod - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.from_legacy_cache def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") @@ -241,7 +242,6 @@ def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device self.register_buffer("inv_freq", inv_freq, persistent=False) @torch.no_grad() - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -509,7 +509,6 @@ class Zamba2FlashAttention2(Zamba2Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -612,7 +611,7 @@ def forward( value_states, attention_mask, q_len, - is_causal = self.is_causal, + is_causal=self.is_causal, dropout=dropout_rate, softmax_scale=softmax_scale, ) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index cb6ea04a312..64bf2a94ce7 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -14,39 +14,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import math -from typing import Any, Dict, List, Optional, Tuple, Union from itertools import cycle +from typing import Any, Dict, Optional, Tuple import torch import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache +from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, - add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, logging, - replace_return_docstrings, ) from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, - is_torchdynamo_compiling, +) +from ..zamba.modeling_zamba import ( + ZambaAttention, + ZambaForCausalLM, + ZambaForSequenceClassification, + ZambaMambaDecoderLayer, + ZambaModel, + ZambaRMSNorm, + repeat_kv, ) @@ -63,18 +61,6 @@ is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) -from ..zamba.modeling_zamba import ( - ZambaMambaDecoderLayer, - ZambaAttention, - ZambaForCausalLM, - ZambaForSequenceClassification, - ZambaModel, - ZambaPreTrainedModel, - ZambaRMSNorm, - repeat_kv, -) -from ...configuration_utils import PretrainedConfig - _CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B" @@ -194,7 +180,6 @@ def __init__( hidden_size=2560, num_hidden_layers=54, layers_block_type=None, - mamba_d_state=64, mamba_d_conv=4, mamba_expand=2, @@ -204,28 +189,24 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=(0.0, float("inf")), - mamba_dt_rank="auto", n_mamba_heads=1, mamba_proj_bias=False, hidden_mamba_act="silu", use_conv_bias=True, chunk_size=256, - - add_bias_linear=False, + add_bias_linear=False, intermediate_size=None, hidden_act="gelu", num_attention_heads=32, num_key_value_heads=None, attention_dropout=0.0, - num_mem_blocks=1, use_shared_block_lora=False, use_shared_attention_lora=False, lora_rank=128, use_mem_rope=False, rope_theta=10000, - initializer_range=0.02, rms_norm_eps=1e-5, use_cache=True, @@ -233,7 +214,6 @@ def __init__( pad_token_id=0, bos_token_id=1, eos_token_id=2, - use_long_context=False, **kwargs, ): @@ -270,23 +250,22 @@ def __init__( self.mamba_headdim = mamba_headdim self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads - self.mamba_conv_bias = mamba_conv_bias self.mamba_proj_bias = mamba_proj_bias self.hidden_mamba_act = hidden_mamba_act self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - + self.use_shared_block_lora = use_shared_block_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank - self.use_long_context=use_long_context + self.use_long_context = use_long_context self.time_step_min = time_step_min self.time_step_max = time_step_max self.time_step_floor = time_step_floor if use_long_context: self.max_position_embeddings = 16384 - + # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads @@ -301,7 +280,6 @@ def __init__( self.use_cache = use_cache 
self.num_logits_to_keep = num_logits_to_keep - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: self.layers_block_type = ( @@ -323,8 +301,8 @@ def count_mem_blocks_in_config(config: Zamba2Config): """ num_gs = 0 for val in config.layers_block_type: - if val == 'hybrid': - num_gs +=1 + if val == "hybrid": + num_gs += 1 return num_gs @@ -333,9 +311,9 @@ def layer_type_list(config: Zamba2Config): Returns list of layer ids containing hybrid layers """ ll = [] - i = 0 + i = 0 for val in config.layers_block_type: - if val == 'hybrid': + if val == "hybrid": ll.append(i) i += 1 return ll @@ -416,7 +394,7 @@ def __init__( ): self.layers_block_type = config.layers_block_type self.transformer_layers = [] - + self.has_previous_state = False self.dtype = dtype self.conv_kernel_size = config.mamba_d_conv @@ -439,7 +417,7 @@ def __init__( ) for i in range(config.num_hidden_layers) } - + for i in range(config.num_hidden_layers): if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) @@ -462,8 +440,8 @@ def update_conv_state( def reset(self): self.conv_states.zero_() self.ssm_states.zero_() - - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.update + + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, key_states: torch.Tensor, @@ -481,7 +459,7 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.reorder_cache + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): @@ -495,7 +473,7 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.get_seq_length + # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" # take any layer that contains cache and not empty tensor @@ -504,17 +482,14 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.to_legacy_cache def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") @classmethod - # Copied from transformers.models.jamba.modeling_jamba.Zamba2DynamicCache.from_legacy_cache def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Zamba class Zamba2RMSNorm(ZambaRMSNorm): pass @@ -535,7 +510,7 @@ def forward(self, hidden_states, gate=None): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) - + ALL_LAYERNORM_LAYERS.append(Zamba2RMSNorm) @@ -547,14 +522,12 @@ def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device self.max_position_embeddings = max_position_embeddings if config.use_long_context: a = 8 - base = base * a ** (dim / (dim-2)) + base = base * a ** (dim / (dim - 2)) self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - @torch.no_grad() - # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @@ -609,7 +582,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class Zamba2Attention(ZambaAttention): """ - Multi-headed attention from 'Attention Is All You Need' paper. + Multi-headed attention from 'Attention Is All You Need' paper. Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads. @@ -618,11 +591,13 @@ class Zamba2Attention(ZambaAttention): Additionally, replaced attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) - Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this + Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer is tied, un-tied LoRA modules are added to the q, k, v projectors to increase expressivity with a small memory overhead. 
""" - def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks = None, block_id: int = None): + def __init__( + self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks=None, block_id: int = None + ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta @@ -637,15 +612,15 @@ def __init__(self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fw self.linear_k_lora_B_list = nn.ModuleList([]) self.linear_v_lora_A_list = nn.ModuleList([]) self.linear_v_lora_B_list = nn.ModuleList([]) - + for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) - linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) - linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias = False) - linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias = False) + linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) + linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) else: linear_q_lora_A = nn.Identity() linear_q_lora_B = nn.Identity() @@ -763,7 +738,6 @@ class Zamba2FlashAttention2(Zamba2Attention): flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -820,7 +794,7 @@ def forward( if self.config.use_mem_rope: cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) @@ -1009,7 +983,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.conv_kernel_size = config.mamba_d_conv self.intermediate_size = int(config.mamba_expand * self.hidden_size) self.layer_idx = layer_idx - self.use_conv_bias = config.use_conv_bias # add this with default True + self.use_conv_bias = config.use_conv_bias self.activation = "silu" self.act = nn.SiLU() @@ -1018,9 +992,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.num_heads = self.intermediate_size // self.head_dim self.chunk_size = config.chunk_size - self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) - self.time_step_min = config.time_step_min # add this, with same default as zamba1 - self.time_step_max = config.time_step_max # add this, with same default as zamba1 + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size self.conv1d = nn.Conv1d( @@ -1122,7 +1096,7 @@ def cuda_kernels_forward( out = self.out_proj(hidden_states)[:, None, ...] # if no cache is found, calling the kernel else: - if attention_mask is not None and not torch.all(attention_mask==1): + if attention_mask is not None and not torch.all(attention_mask == 1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1131,7 +1105,7 @@ def cuda_kernels_forward( A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} if attention_mask is not None: - input_not_masked = torch.all(attention_mask==1) + input_not_masked = torch.all(attention_mask == 1) else: input_not_masked = True @@ -1166,10 +1140,9 @@ def cuda_kernels_forward( # 1D Convolution if cache_params is not None: - hidden_states_B_C_t = hidden_states_B_C.transpose(1,2) + hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2) conv_state = nn.functional.pad( - hidden_states_B_C_t, - (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) + hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0) ) cache_params.conv_states[self.layer_idx].copy_(conv_state) if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: @@ -1188,7 +1161,7 @@ def cuda_kernels_forward( [self.intermediate_size, groups_time_state_size, groups_time_state_size], dim=-1, ) - if attention_mask is not None and not torch.all(attention_mask==1): + if attention_mask is not None and not torch.all(attention_mask == 1): # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 dtype = hidden_states.dtype hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) @@ -1421,9 +1394,9 @@ def 
forward( class Zamba2MLP(nn.Module): - def __init__(self, config: Zamba2Config, num_fwd_mem_blocks = None, block_id: int = None): + def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ - This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer + This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__() @@ -1432,34 +1405,36 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks = None, block_id: in self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id self.ffn_intermediate_size = config.intermediate_size - + self.act_fn = ACT2FN[config.hidden_act] + def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) return self.act_fn(x[0]) * x[1] + self.gated_act_fn = gated_act_fn - self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias = config.add_bias_linear) + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) - + if self.config.use_shared_block_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) self.gate_up_proj_lora_B_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias = False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias = False) + gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) else: gate_up_proj_lora_A = nn.Identity() gate_up_proj_lora_B = nn.Identity() self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) - + layer_block_map = layer_type_list(config) self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} - def forward(self, hidden_state, layer_idx = None): + def forward(self, hidden_state, layer_idx=None): if self.config.use_shared_block_lora: layer_idx = self.layer_dic[layer_idx] gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] @@ -1481,7 +1456,9 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option super().__init__() num_gs = count_mem_blocks_in_config(config) self.block_id = block_id - self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) + self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id + ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1531,7 +1508,7 @@ def forward( cache_position=cache_position, **kwargs, ) - + hidden_states = self.pre_ff_layernorm(hidden_states) hidden_states = self.feed_forward(hidden_states, layer_idx) @@ -1627,6 +1604,7 @@ def forward( return layer_outputs + ZAMBA2_START_DOCSTRING 
= r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -1658,7 +1636,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2DynamicCache _is_stateful = True - + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, (nn.Linear, nn.Conv1d)): @@ -1675,8 +1653,7 @@ def _init_weights(self, module): num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( - torch.rand(num_heads) - * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1772,7 +1749,7 @@ def __init__(self, config: Zamba2Config): Zamba2PreTrainedModel.__init__(self, config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)] mamba_layers = [] @@ -1811,9 +1788,9 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + self.gradient_checkpointing = False - + # Initialize weights and apply final processing self.post_init() @@ -1827,7 +1804,7 @@ def __init__(self, config: Zamba2Config): # Initialize weights and apply final processing self.post_init() - + # Adapted from transformers.models.zamba.modeling_zamba.ZambaForCausalLM.prepare_inputs_for_generation # with `Zamba2DynamicCache` -> `Zamba2DynamicCache` def prepare_inputs_for_generation( @@ -1855,9 +1832,7 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = Zamba2DynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) + past_key_values = Zamba2DynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1907,4 +1882,4 @@ def __init__(self, config): self._tied_weights_keys = self.model._tied_weights_keys # Initialize weights and apply final processing - self.post_init() \ No newline at end of file + self.post_init() From 9d9b2eb75f8071188c1f520cc9d8d99c122c73c7 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 9 Nov 2024 00:37:56 +0000 Subject: [PATCH 15/73] text_modeling_zamba2 --- .../models/zamba2/configuration_zamba2.py | 112 +-- .../models/zamba2/modeling_zamba2.py | 61 +- .../models/zamba2/modular_zamba2.py | 81 ++- tests/generation/test_utils.py | 1 + tests/models/zamba2/__init__.py | 0 tests/models/zamba2/test_modeling_zamba2.py | 640 ++++++++++++++++++ 6 files changed, 791 insertions(+), 104 deletions(-) create mode 100644 tests/models/zamba2/__init__.py create mode 100644 tests/models/zamba2/test_modeling_zamba2.py diff --git a/src/transformers/models/zamba2/configuration_zamba2.py 
b/src/transformers/models/zamba2/configuration_zamba2.py index 2502d7c4de5..08fdca8aed3 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -38,15 +38,48 @@ class Zamba2Config(PretrainedConfig): vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. - ffn_hidden_size (`int`, *optional*, defaults to 4 * hidden_size): - Dimension of the MLP representations. num_hidden_layers (`int`, *optional*, defaults to 54): Number of hidden layers in the model. + layers_block_type (`list`, *optional*): + List of layer types, which can be either "mamba" or "hybrid". + mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. + mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. + mamba_ngroups (`int`, *optional*, defaults to 8): + Number of groups for the evolution matrices of mamba 2. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + time_step_limit (`tuple`, *optional*): + Accepted range of time step values. + mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` + n_mamba_heads (`int`, *optional*, defaults to 1): + Number of heads for the evolution matrices of mamba 2. + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) adjacent to the mamba conv. + chunk_size (`int`, *optional*, defaults to 256): + Size of the chunks that will comprise the sequence. + add_bias_linear (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in various layers + intermediate_size (`int`, *optional*, defaults to 4 * hidden_size): + Dimension of the MLP representations. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the MLP. num_attention_heads (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer in the Transformer decoder. 
num_key_value_heads (`int`, *optional*): @@ -56,11 +89,23 @@ class Zamba2Config(PretrainedConfig): converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group. For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf). - mamba_headdim (``, *optional*, defaults to 64): - dimension of each Mamba2 heads (number of heads is set to 1 in this implementation). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_mem_blocks (`int`, *optional*, defaults to 1): + Number of unshared transformer blocks. + use_shared_block_lora (`bool`, *optional*, defaults to `False`): + If True, unshared LoRA's will be added to the shared MLP's. + use_shared_attention_lora (`bool`, *optional*, defaults to `False`): + If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. + lora_rank (`int`, *optional*, defaults to 128): + Rank of the LoRA in the shared MLP and shared attention layers. + use_mem_rope (`bool`, *optional*, defaults to `False`): + If True, includes RoPE in the shared attention layers. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): + rms_norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -77,29 +122,6 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. - sliding_window (`int`, *optional*): - Sliding window attention window size. If not specified, will default to `None`. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - use_mamba_kernels (`bool`, *optional*, defaults to `True`): - Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and - `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. 
Raises ValueError if - `True` and kernels are not available - state_size (`int`, *optional*, defaults to 16): - The dimension the mamba state space latents - mamba_d_conv (`int`, *optional*, defaults to 4): - The size of the mamba convolution kernel - mamba_expand (`int`, *optional*, defaults to 2): - Expanding factor (relative to hidden_size) used to determine the mamba intermediate size - add_bias_linear (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to use bias in various layers - gated_linear_unit (`bool`, *optional*, defaults to `True`): - Flag indicating whether or not to use gated MLP - use_shared_block_lora (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP - inside the shared transformer blocks - state_size (`int`, *optional*, defaults to 128): - The rank of the LoRA modules inside the MLP of the shared transformer blocks """ model_type = "zamba2" @@ -116,37 +138,29 @@ def __init__( mamba_d_state=64, mamba_d_conv=4, mamba_expand=2, - mamba_headdim=64, mamba_ngroups=1, time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4, - time_step_limit=(0.0, float("inf")), + time_step_limit=None, mamba_dt_rank="auto", n_mamba_heads=1, - mamba_conv_bias=True, mamba_proj_bias=False, hidden_mamba_act="silu", - use_mamba_kernels=True, use_conv_bias=True, chunk_size=256, add_bias_linear=False, intermediate_size=None, - gated_linear_unit=True, hidden_act="gelu", num_attention_heads=32, num_key_value_heads=None, - sliding_window=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_block_lora=True, + use_shared_block_lora=False, use_shared_attention_lora=False, lora_rank=128, - use_mem_eff_path=True, use_mem_rope=False, rope_theta=10000, - attention_hidden_size=None, - attention_head_dim=None, initializer_range=0.02, rms_norm_eps=1e-5, use_cache=True, @@ -176,36 +190,26 @@ def __init__( self.hidden_act = hidden_act self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads - self.sliding_window = sliding_window self.num_mem_blocks = num_mem_blocks self.use_mem_rope = use_mem_rope self.rope_theta = rope_theta - if attention_hidden_size is None: - self.attention_hidden_size = 2 * hidden_size - else: - self.attention_hidden_size = attention_hidden_size - if attention_head_dim is None: - self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads - else: - self.attention_head_dim = attention_head_dim + self.attention_hidden_size = 2 * hidden_size + self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads self.attention_dropout = attention_dropout self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear - self.mamba_headdim = mamba_headdim + self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads - self.mamba_conv_bias = mamba_conv_bias self.mamba_proj_bias = mamba_proj_bias self.hidden_mamba_act = hidden_mamba_act - self.use_mamba_kernels = use_mamba_kernels self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.gated_linear_unit = gated_linear_unit self.use_shared_block_lora = use_shared_block_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -227,12 
+231,8 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps - if intermediate_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - self.use_cache = use_cache self.num_logits_to_keep = num_logits_to_keep - self.use_mem_eff_path = use_mem_eff_path # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: @@ -246,4 +246,4 @@ def __init__( + ["mamba"] * 2 ) else: - self.layers_block_type = layers_block_type + self.layers_block_type = layers_block_type \ No newline at end of file diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 48256cbe4bb..9598f6dc497 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -92,6 +92,9 @@ def __init__( ): self.layers_block_type = config.layers_block_type self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} self.has_previous_state = False self.dtype = dtype @@ -431,7 +434,6 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: - layer_idx = self.layer_dic[layer_idx] lora_layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] @@ -810,7 +812,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.n_groups = config.mamba_ngroups self.head_dim = config.mamba_headdim - self.num_heads = self.intermediate_size // self.head_dim + self.num_heads = self.config.n_mamba_heads self.chunk_size = config.chunk_size self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) @@ -924,7 +926,7 @@ def cuda_kernels_forward( # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + dt_limit_kwargs = {} if self.time_step_limit == None else {"dt_limit": self.time_step_limit} if attention_mask is not None: input_not_masked = torch.all(attention_mask == 1) else: @@ -1551,9 +1553,9 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim + # num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( - torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1671,19 +1673,42 @@ def __init__(self, config: Zamba2Config): self._tied_weights_keys = [] for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": - prefix_name = f"layers.{layer_id}." 
- tied_keys = [ - "shared_transf.self_attn.q_proj.weight", - "shared_transf.self_attn.k_proj.weight", - "shared_transf.self_attn.v_proj.weight", - "shared_transf.self_attn.o_proj.weight", - "shared_transf.feed_forward.gate_proj.weight", - "shared_transf.feed_forward.up_proj.weight", - "shared_transf.feed_forward.down_proj.weight", - "shared_transf.input_layernorm.weight", - "shared_transf.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + block = next(blocks) + if config.num_mem_blocks * len(layer_type_list(config)) > 1: + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + if config.use_shared_block_lora: + tied_keys_lora = [] + lora_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: + tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + lora_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] + if config.use_shared_attention_lora: + tied_keys_lora = [] + lora_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: + tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' + str(lora_id) + '.weight') + lora_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 64bf2a94ce7..002a453b1ea 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -95,8 +95,6 @@ class Zamba2Config(PretrainedConfig): mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_headdim (`int`, *optional*, defaults to 64): - Dimension of each mamba head. mamba_ngroups (`int`, *optional*, defaults to 8): Number of groups for the evolution matrices of mamba 2. 
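As introduced earlier in this patch, the per-head Mamba dimension is no longer a standalone argument: it is derived from the expansion factor, the hidden size, and the number of Mamba heads. A rough restatement of that relationship using the documented defaults (sketch only, not code from the patch):

hidden_size = 2560      # documented default
mamba_expand = 2        # documented default
n_mamba_heads = 1       # documented default

intermediate_size = int(mamba_expand * hidden_size)  # 5120
mamba_headdim = intermediate_size // n_mamba_heads   # 5120 when a single Mamba head is used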
time_step_min (`float`, *optional*, defaults to 0.001): @@ -105,7 +103,7 @@ class Zamba2Config(PretrainedConfig): Maximum `time_step` used to bound `dt_proj.bias`. time_step_floor (`float`, *optional*, defaults to 0.0001): Minimum clamping value of the `dt_proj.bias` layer initialization. - time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`): + time_step_limit (`tuple`, *optional*): Accepted range of time step values. mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` @@ -138,7 +136,7 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_block_lora (`bool`, *optional*, defaults to `False`): + use_shared_mlp_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the shared MLP's. use_shared_attention_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. @@ -183,12 +181,11 @@ def __init__( mamba_d_state=64, mamba_d_conv=4, mamba_expand=2, - mamba_headdim=64, mamba_ngroups=1, time_step_min=0.001, time_step_max=0.1, time_step_floor=1e-4, - time_step_limit=(0.0, float("inf")), + time_step_limit=None, mamba_dt_rank="auto", n_mamba_heads=1, mamba_proj_bias=False, @@ -202,7 +199,7 @@ def __init__( num_key_value_heads=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_block_lora=False, + use_shared_mlp_lora=False, use_shared_attention_lora=False, lora_rank=128, use_mem_rope=False, @@ -247,16 +244,16 @@ def __init__( self.mamba_expand = mamba_expand self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear - self.mamba_headdim = mamba_headdim self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads + self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.mamba_proj_bias = mamba_proj_bias self.hidden_mamba_act = hidden_mamba_act self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_block_lora = use_shared_block_lora + self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank self.use_long_context = use_long_context @@ -394,6 +391,9 @@ def __init__( ): self.layers_block_type = config.layers_block_type self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} self.has_previous_state = False self.dtype = dtype @@ -660,7 +660,6 @@ def forward( bsz, q_len, _ = hidden_states.size() if self.config.use_shared_attention_lora: - layer_idx = self.layer_dic[layer_idx] lora_layer_idx = self.layer_dic[layer_idx] linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] @@ -989,7 +988,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.n_groups = config.mamba_ngroups self.head_dim = config.mamba_headdim - self.num_heads = self.intermediate_size // self.head_dim + self.num_heads = self.config.n_mamba_heads self.chunk_size = config.chunk_size self.time_step_limit = config.time_step_limit @@ -1103,7 +1102,7 @@ def cuda_kernels_forward( # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit} if attention_mask is not None: input_not_masked = torch.all(attention_mask == 1) else: @@ -1418,7 +1417,7 @@ def gated_act_fn(x): self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) - if self.config.use_shared_block_lora: + if self.config.use_shared_mlp_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) self.gate_up_proj_lora_B_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): @@ -1435,7 +1434,7 @@ def gated_act_fn(x): self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): - if self.config.use_shared_block_lora: + if self.config.use_shared_mlp_lora: layer_idx = self.layer_dic[layer_idx] gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] @@ -1651,9 +1650,8 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( - torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1768,20 +1766,43 @@ def __init__(self, config: Zamba2Config): self._tied_weights_keys = [] for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transf.self_attn.q_proj.weight", - "shared_transf.self_attn.k_proj.weight", - "shared_transf.self_attn.v_proj.weight", - "shared_transf.self_attn.o_proj.weight", - "shared_transf.feed_forward.gate_proj.weight", - "shared_transf.feed_forward.up_proj.weight", - "shared_transf.feed_forward.down_proj.weight", - "shared_transf.input_layernorm.weight", - "shared_transf.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) + block = next(blocks) + if config.num_mem_blocks * len(layer_type_list(config)) > 1: + prefix_name = f"layers.{layer_id}." 
+ tied_keys = [ + "shared_transf.self_attn.q_proj.weight", + "shared_transf.self_attn.k_proj.weight", + "shared_transf.self_attn.v_proj.weight", + "shared_transf.self_attn.o_proj.weight", + "shared_transf.feed_forward.gate_up_proj.weight", + "shared_transf.feed_forward.down_proj.weight", + "shared_transf.input_layernorm.weight", + "shared_transf.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + if config.use_shared_mlp_lora: + tied_keys_lora = [] + lora_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: + tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + lora_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] + if config.use_shared_attention_lora: + tied_keys_lora = [] + lora_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: + tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' + str(lora_id) + '.weight') + lora_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] + layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cbe851e97e9..c351c2a5a43 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2271,6 +2271,7 @@ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1 "mamba", "xlnet", "zamba", + "zamba2", ) has_standard_cache = not any( model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache diff --git a/tests/models/zamba2/__init__.py b/tests/models/zamba2/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py new file mode 100644 index 00000000000..5877a5398a7 --- /dev/null +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -0,0 +1,640 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the PyTorch Zamba model.""" + +import math +import tempfile +import unittest + +import pytest +from parameterized import parameterized + +from transformers import AutoTokenizer, Zamba2Config, is_torch_available +from transformers.testing_utils import ( + require_bitsandbytes, + require_flash_attn, + require_torch, + require_torch_gpu, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Zamba2ForCausalLM, + Zamba2ForSequenceClassification, + Zamba2Model, + ) + from transformers.models.zamba2.modeling_zamba2 import ( + Zamba2DynamicCache, + ) + + +class Zamba2ModelTester: + def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_input_mask=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + mamba_d_state=2, + chunk_size=8, + mamba_dt_rank="auto", + num_hidden_layers=2, + num_attention_heads=2, + n_mamba_heads=8, + mamba_ngroups=8, + intermediate_size=32, + hidden_act="gelu", + hidden_mamba_act="silu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + layers_block_type = ['mamba', 'hybrid'], + num_mem_blocks=1, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.mamba_dt_rank = mamba_dt_rank + self.mamba_d_state = mamba_d_state + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_mamba_heads = n_mamba_heads + self.mamba_ngroups = mamba_ngroups + self.chunk_size = chunk_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_mamba_act = hidden_mamba_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.layers_block_type = layers_block_type + self.num_mem_blocks = num_mem_blocks + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return Zamba2Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + 
mamba_dt_rank=self.mamba_dt_rank, + mamba_d_state = self.mamba_d_state, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + n_mamba_heads=self.n_mamba_heads, + intermediate_size=self.intermediate_size, + chunk_size=self.chunk_size, + hidden_act=self.hidden_act, + mamba_ngroups=self.mamba_ngroups, + hidden_mamba_act=self.hidden_mamba_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=True, + initializer_range=self.initializer_range, + use_mamba_kernels=False, + layers_block_type=self.layers_block_type, + num_mem_blocks=self.num_mem_blocks, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + config.is_decoder = True + + return ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = Zamba2Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = Zamba2ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids, labels=token_labels) + result = model(input_ids) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + config.is_decoder = True + config.add_cross_attention = False + model = Zamba2ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + # Attention: Zamba2 needs the cache to be initialized to return a cache! 
+ past_key_values = Zamba2DynamicCache( + config, input_ids.shape[0], model.dtype, device=model.device + ) + outputs = model( + input_ids, + attention_mask=input_mask, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 1), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + cache_position=torch.arange( + input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device + ), + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -1:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def create_and_check_for_sequence_classification( + self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = Zamba2ForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + test_torchscript = False + all_model_classes = ( + ( + Zamba2Model, + Zamba2ForCausalLM, + Zamba2ForSequenceClassification, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (Zamba2ForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Zamba2Model, + "text-classification": Zamba2ForSequenceClassification, + "text-generation": Zamba2ForCausalLM, + "zero-shot": Zamba2ForSequenceClassification, + } + if is_torch_available() + else {} + ) + test_headmasking = False + test_pruning = False + + def setUp(self): + self.model_tester = Zamba2ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37) + + @unittest.skip("position_ids cannot be used to pad due to Mamba2 layers") + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) 
+ + def test_for_casual_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_initialization(self): + r""" + Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + if "A_log" in name: + A = torch.arange(1, config.n_mamba_heads + 1, dtype=torch.float32)[None, :] + self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5)) + elif "D" in name: + # check if it's a ones like + self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) + elif "dt_bias" in name: + dt = torch.exp( + torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min)) + + math.log(config.time_step_min) + ).clamp(min=config.time_step_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + if param.requires_grad: + self.assertTrue(param.data.max().item() <= inv_dt[1]) + self.assertTrue(param.data.min().item() >= inv_dt[0]) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_mismatched_shapes_have_properly_initialized_weights(self): + r""" + Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the + Mamba block are initialized differently and we tested that in test_initialization + """ + self.skipTest("Cumbersome and redundant for Zamba2") + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the Zamba2 model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + + self.assertListEqual( + list(attentions[0].shape[-3:]), + 
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def _get_input_ids_and_config(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + ( + config, + input_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + return config, input_ids, input_mask + + def test_left_padding_compatibility(self): + r""" + Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences + effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value. + """ + import inspect + # NOTE: left-padding results in small numerical differences. This is expected. + # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535 + + # First, filter out models that don't support left padding - generative and decoder-only. + # Zamba2 is a decoder-only architecture + decoder_only_classes = self.all_generative_model_classes + + # Then, test left-padding + def _prepare_model_kwargs(input_ids, attention_mask, signature): + model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask} + if "position_ids" in signature: + position_ids = torch.cumsum(attention_mask, dim=-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + model_kwargs["position_ids"] = position_ids + if "cache_position" in signature: + cache_position = torch.arange(input_ids.shape[-1], device=torch_device) + model_kwargs["cache_position"] = cache_position + return model_kwargs + + for model_class in decoder_only_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + signature = inspect.signature(model.forward).parameters.keys() + + # Without padding + model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature) + next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :] + + # With left-padding (length 32) + pad_size = (input_ids.shape[0], 32) + padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id + padded_input_ids = torch.cat((padding, input_ids), dim=1) + padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1) + model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature) + next_logits_with_padding = model(**model_kwargs).logits[:, -1, :] + + # They should result in very similar logits + self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3)) + + @require_flash_attn + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_fp32_ln(self): + r""" + Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + for model_class in self.all_generative_model_classes: + 
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_input = inputs_dict[model.main_input_name] + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Zamba2 does not support right padding + use_cache with FA2. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + load_in_4bit=True, + ) + + for _, param in model.named_parameters(): + # upcast only layer norms + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + _ = model(dummy_input) + # with attention mask + _ = model(dummy_input, attention_mask=dummy_attention_mask) + + @require_flash_attn + @require_torch_gpu + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + r""" + Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't support + right padding + use cache with FA2 + """ + self.skipTest(reason="Zamba2 flash attention does not support right padding") + + @unittest.skip(reason="Zamba2 has its own special cache type") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + +@require_torch +class Zamba2ModelIntegrationTest(unittest.TestCase): + model = None + tokenizer = None + + @classmethod + @slow + def setUpClass(cls): + model_id = "Zyphra/Zamba2-1.2B" + cls.model = Zamba2ForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True, revision="PR" + ) + cls.tokenizer = AutoTokenizer.from_pretrained(model_id, revision="PR") + + @slow + def test_simple_generate(self): + self.model.to(torch_device) + + input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[ + "input_ids" + ].to(torch_device) + out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10) + output_sentence = self.tokenizer.decode(out[0, :]) + self.assertEqual( + output_sentence, + " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for", + ) + + with torch.no_grad(): + logits = self.model(input_ids=input_ids).logits.to(dtype=torch.float32) + + EXPECTED_LOGITS_NO_GRAD = torch.tensor( + [ + -5.9587, 10.5152, 7.0382, -2.8728, -4.8143, -4.8142, -4.8142, -4.8144, + -4.8143, -4.8143, -4.8142, -4.8142, 6.0185, 18.0037, -4.8142, -4.8144, + -4.8143, -4.8142, -4.8143, -4.8143, -4.8143, -4.8143, -4.8142, -4.8143, + -4.8144, -4.8143, -4.8143, -4.8141, -4.8142, -4.8142, -4.8142, -4.8144, + -4.8143, -4.8143, -4.8143, -4.8142, -4.8144, -4.8144, -4.8142, -4.8142 + ] + , dtype=torch.float32) # fmt: skip + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3) + + @slow + def test_simple_batched_generate_with_padding(self): + self.model.to(torch_device) + + inputs = self.tokenizer( + ["Hey how are you doing on this lovely evening?", "When did the Roman empire "], padding=True, return_tensors="pt" + ).to(torch_device) + out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) + output_sentences = self.tokenizer.batch_decode(out) + self.assertEqual( + output_sentences[0], + " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for", + ) + + 
self.assertEqual( + output_sentences[1], + "[PAD][PAD][PAD][PAD] When did the Roman empire 1st fall?\nThe Roman Empire fell in", + ) + + with torch.no_grad(): + logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(dtype=torch.float32) + + EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( + [ + -5.9611, 10.5208, 7.0411, -2.8743, -4.8167, -4.8167, -4.8167, -4.8168, + -4.8167, -4.8167, -4.8167, -4.8166, 6.0218, 18.0062, -4.8167, -4.8168, + -4.8167, -4.8167, -4.8167, -4.8168, -4.8168, -4.8168, -4.8167, -4.8167, + -4.8168, -4.8167, -4.8167, -4.8165, -4.8167, -4.8167, -4.8167, -4.8169, + -4.8168, -4.8168, -4.8168, -4.8166, -4.8169, -4.8168, -4.8167, -4.8167 + ] + , dtype=torch.float32) # fmt: skip + + EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor( + [ + 0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104, + -6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096, + -6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106, + -6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105, + -6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105 ] + , dtype=torch.float32) # fmt: skip + + torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3) From 549d4cb4de178cbc83a4aab1078934742c219254 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 9 Nov 2024 01:09:41 +0000 Subject: [PATCH 16/73] small fixes --- src/transformers/models/zamba2/configuration_zamba2.py | 6 +++--- src/transformers/models/zamba2/modeling_zamba2.py | 6 +++--- src/transformers/models/zamba2/modular_zamba2.py | 4 +--- tests/models/zamba2/test_modeling_zamba2.py | 2 +- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 08fdca8aed3..2acd266a425 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -93,7 +93,7 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_block_lora (`bool`, *optional*, defaults to `False`): + use_shared_mlp_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the shared MLP's. use_shared_attention_lora (`bool`, *optional*, defaults to `False`): If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. 
@@ -156,7 +156,7 @@ def __init__( num_key_value_heads=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_block_lora=False, + use_shared_mlp_lora=False, use_shared_attention_lora=False, lora_rank=128, use_mem_rope=False, @@ -210,7 +210,7 @@ def __init__( self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_block_lora = use_shared_block_lora + self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank self.use_long_context = use_long_context diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 9598f6dc497..a92c59ac81e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1248,7 +1248,7 @@ def gated_act_fn(x): self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) - if self.config.use_shared_block_lora: + if self.config.use_shared_mlp_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) self.gate_up_proj_lora_B_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): @@ -1265,7 +1265,7 @@ def gated_act_fn(x): self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): - if self.config.use_shared_block_lora: + if self.config.use_shared_mlp_lora: layer_idx = self.layer_dic[layer_idx] gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] @@ -1687,7 +1687,7 @@ def __init__(self, config: Zamba2Config): "shared_transf.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - if config.use_shared_block_lora: + if config.use_shared_mlp_lora: tied_keys_lora = [] lora_id = 0 for _layer_type in self.layers_block_type: diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 002a453b1ea..7facf6c1701 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -325,7 +325,6 @@ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): Assumes that we only have tensors of either size 4 or 3 """ - # pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) @@ -341,7 +340,6 @@ def reshape_into_chunks(input_tensor, pad_size, chunk_size): # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
input_tensor = pad_tensor_by_size(input_tensor, pad_size) - # if len(input_tensor.shape) == 3: if input_tensor.ndim == 3: # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) @@ -1117,7 +1115,7 @@ def cuda_kernels_forward( A, D=self.D, chunk_size=self.chunk_size, - seq_idx=None, # was seq_idx + seq_idx=None, activation=self.activation, rmsnorm_weight=self.norm.weight, rmsnorm_eps=self.norm.variance_epsilon, diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 5877a5398a7..cdfb84b40bb 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -68,7 +68,7 @@ def __init__( num_attention_heads=2, n_mamba_heads=8, mamba_ngroups=8, - intermediate_size=32, + intermediate_size=16, hidden_act="gelu", hidden_mamba_act="silu", hidden_dropout_prob=0.1, From 987bba9f8de85eb7a85ebbd42a37dfb1dbe73cde Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 11 Nov 2024 06:24:49 +0000 Subject: [PATCH 17/73] make fixup fixes --- docs/source/en/perf_infer_gpu_one.md | 3 +- .../models/zamba2/configuration_zamba2.py | 59 ++--- .../models/zamba2/modeling_zamba2.py | 213 +++++++++++++----- .../models/zamba2/modular_zamba2.py | 96 ++++---- tests/models/zamba2/test_modeling_zamba2.py | 18 +- 5 files changed, 238 insertions(+), 151 deletions(-) diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 67bd31fdaee..ce984574963 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -101,6 +101,7 @@ FlashAttention-2 is currently supported for the following architectures: * [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) * [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) * [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) +* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2) You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request. @@ -304,7 +305,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel) * [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel) * [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel) - +* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2) FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models. diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 2acd266a425..5463ca6e8f5 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -19,7 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math + from ...configuration_utils import PretrainedConfig @@ -30,19 +30,16 @@ class Zamba2Config(PretrainedConfig): Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Zamba2 model. + [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] max_position_embeddings (`int`, *optional*, defaults to 4096): The maximum sequence length that this model might ever be used with. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. num_hidden_layers (`int`, *optional*, defaults to 54): @@ -52,7 +49,7 @@ class Zamba2Config(PretrainedConfig): mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_ngroups (`int`, *optional*, defaults to 8): + mamba_ngroups (`int`, *optional*, defaults to 1): Number of groups for the evolution matrices of mamba 2. time_step_min (`float`, *optional*, defaults to 0.001): Minimum `time_step` used to bound `dt_proj.bias`. @@ -62,16 +59,10 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): - Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` n_mamba_heads (`int`, *optional*, defaults to 1): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. - mamba_proj_bias (`bool`, *optional*, defaults to `False`): - Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block - hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) adjacent to the mamba conv. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. add_bias_linear (`bool`, *optional*, defaults to `False`): @@ -101,11 +92,11 @@ class Zamba2Config(PretrainedConfig): Rank of the LoRA in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. - rope_theta (`float`, *optional*, defaults to 10000.0): + rope_theta (`float`, *optional*, defaults to `10000.0`): The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
- rms_norm_eps (`float`, *optional*, defaults to 1e-5): + rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -122,6 +113,16 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. + use_long_context (`bool`, *optional*, defaults to `False`): + Activates the context-extended version of Zamba by modifying RoPE. + ```python + >>> from transformers import Zamba2Model, Zamba2Config + >>> # Initializing a Zamba2-2.7B style configuration + >>> configuration = Zamba2Config() + >>> # Initializing a model from the Zamba2-2.7B style configuration + >>> model = Zamba2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config """ model_type = "zamba2" @@ -131,7 +132,6 @@ def __init__( self, vocab_size=32000, max_position_embeddings=4096, - tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, layers_block_type=None, @@ -143,10 +143,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - mamba_dt_rank="auto", n_mamba_heads=1, - mamba_proj_bias=False, - hidden_mamba_act="silu", use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -175,13 +172,10 @@ def __init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size if intermediate_size is None: self.intermediate_size = 4 * hidden_size @@ -199,17 +193,13 @@ def __init__( self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear - self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads - self.mamba_proj_bias = mamba_proj_bias - self.hidden_mamba_act = hidden_mamba_act + self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -219,21 +209,12 @@ def __init__( self.time_step_floor = time_step_floor if use_long_context: self.max_position_embeddings = 16384 - - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads self.kv_channels = self.hidden_size // self.num_attention_heads self.num_query_groups = self.num_attention_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - - self.use_cache = use_cache - self.num_logits_to_keep = num_logits_to_keep - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: self.layers_block_type = ( @@ -246,4 +227,8 @@ def __init__( + ["mamba"] * 2 ) else: - self.layers_block_type = 
layers_block_type \ No newline at end of file + self.layers_block_type = layers_block_type + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a92c59ac81e..e78c155a1f2 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -56,12 +56,8 @@ if is_mamba_ssm_available(): - #### from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import ( #### added - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - ) else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None @@ -70,7 +66,32 @@ else: causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) #### added +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) + + +logger = logging.get_logger(__name__) + + +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Zamba2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Zamba2DynamicCache(DynamicCache): @@ -191,29 +212,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -logger = logging.get_logger(__name__) - - -class Zamba2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Zamba2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -273,6 +271,92 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). 
+ + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + intermediate_size = config.mamba_expand * config.hidden_size + ssm_state_size = config.mamba_d_state + conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} + for i in range(config.num_hidden_layers): + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers @@ -518,7 +602,6 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - self.is_causal = True def forward( self, @@ -613,9 +696,9 @@ def forward( value_states, attention_mask, q_len, - is_causal=self.is_causal, dropout=dropout_rate, softmax_scale=softmax_scale, + is_causal=self.is_causal, ) attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() @@ -743,7 +826,6 @@ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): Assumes that we only have tensors of either size 4 or 3 """ - # pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) @@ -759,7 +841,6 @@ def reshape_into_chunks(input_tensor, pad_size, chunk_size): # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] 
input_tensor = pad_tensor_by_size(input_tensor, pad_size) - # if len(input_tensor.shape) == 3: if input_tensor.ndim == 3: # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) @@ -806,7 +887,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.conv_kernel_size = config.mamba_d_conv self.intermediate_size = int(config.mamba_expand * self.hidden_size) self.layer_idx = layer_idx - self.use_conv_bias = config.use_conv_bias # add this with default True + self.use_conv_bias = config.use_conv_bias self.activation = "silu" self.act = nn.SiLU() @@ -815,9 +896,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): self.num_heads = self.config.n_mamba_heads self.chunk_size = config.chunk_size - self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf")) - self.time_step_min = config.time_step_min # add this, with same default as zamba1 - self.time_step_max = config.time_step_max # add this, with same default as zamba1 + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size self.conv1d = nn.Conv1d( @@ -926,7 +1007,7 @@ def cuda_kernels_forward( # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) - dt_limit_kwargs = {} if self.time_step_limit == None else {"dt_limit": self.time_step_limit} + dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit} if attention_mask is not None: input_not_masked = torch.all(attention_mask == 1) else: @@ -941,7 +1022,7 @@ def cuda_kernels_forward( A, D=self.D, chunk_size=self.chunk_size, - seq_idx=None, # was seq_idx + seq_idx=None, activation=self.activation, rmsnorm_weight=self.norm.weight, rmsnorm_eps=self.norm.variance_epsilon, @@ -1379,7 +1460,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1390,7 +1471,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
@@ -1553,9 +1634,9 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - # num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim dt = torch.exp( - torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1604,14 +1685,14 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `Zamba2DynamicCache` class for more details. + See the `HybridMambaAttentionDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1692,8 +1773,12 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] if config.use_shared_attention_lora: @@ -1701,15 +1786,27 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' 
+ str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] - layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers))) + layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) self.layers = nn.ModuleList(layers) @@ -1734,7 +1831,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1771,7 +1868,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was " + "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -1917,7 +2014,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 7facf6c1701..3b254481a90 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -73,19 +73,16 @@ class Zamba2Config(PretrainedConfig): Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the Zamba2 model. + [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - - Args: vocab_size (`int`, *optional*, defaults to 32000): Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Zamba2Model`] max_position_embeddings (`int`, *optional*, defaults to 4096): The maximum sequence length that this model might ever be used with. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the - model has a output word embedding layer. 
hidden_size (`int`, *optional*, defaults to 2560): Dimension of the hidden representations. num_hidden_layers (`int`, *optional*, defaults to 54): @@ -95,7 +92,7 @@ class Zamba2Config(PretrainedConfig): mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_ngroups (`int`, *optional*, defaults to 8): + mamba_ngroups (`int`, *optional*, defaults to 1): Number of groups for the evolution matrices of mamba 2. time_step_min (`float`, *optional*, defaults to 0.001): Minimum `time_step` used to bound `dt_proj.bias`. @@ -105,16 +102,10 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): - Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` n_mamba_heads (`int`, *optional*, defaults to 1): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. - mamba_proj_bias (`bool`, *optional*, defaults to `False`): - Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block - hidden_mamba_act (`str`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) adjacent to the mamba conv. chunk_size (`int`, *optional*, defaults to 256): Size of the chunks that will comprise the sequence. add_bias_linear (`bool`, *optional*, defaults to `False`): @@ -144,11 +135,11 @@ class Zamba2Config(PretrainedConfig): Rank of the LoRA in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. - rope_theta (`float`, *optional*, defaults to 10000.0): + rope_theta (`float`, *optional*, defaults to `10000.0`): The base period of the RoPE embeddings. initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-5): + rms_norm_eps (`float`, *optional*, defaults to 1e-05): The epsilon used by the rms normalization layers. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only @@ -165,6 +156,16 @@ class Zamba2Config(PretrainedConfig): The id of the "beginning-of-sequence" token. eos_token_id (`int`, *optional*, defaults to 2): The id of the "end-of-sequence" token. + use_long_context (`bool`, *optional*, defaults to `False`): + Activates the context-extended version of Zamba by modifying RoPE. 
+ ```python + >>> from transformers import Zamba2Model, Zamba2Config + >>> # Initializing a Zamba2-2.7B style configuration + >>> configuration = Zamba2Config() + >>> # Initializing a model from the Zamba2-2.7B style configuration + >>> model = Zamba2Model(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config """ model_type = "zamba2" @@ -174,7 +175,6 @@ def __init__( self, vocab_size=32000, max_position_embeddings=4096, - tie_word_embeddings=True, hidden_size=2560, num_hidden_layers=54, layers_block_type=None, @@ -186,10 +186,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - mamba_dt_rank="auto", n_mamba_heads=1, - mamba_proj_bias=False, - hidden_mamba_act="silu", use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -218,13 +215,10 @@ def __init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, **kwargs, ) - self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings - self.tie_word_embeddings = tie_word_embeddings self.hidden_size = hidden_size if intermediate_size is None: self.intermediate_size = 4 * hidden_size @@ -242,17 +236,13 @@ def __init__( self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand - self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank self.add_bias_linear = add_bias_linear self.mamba_ngroups = mamba_ngroups self.n_mamba_heads = n_mamba_heads self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads - self.mamba_proj_bias = mamba_proj_bias - self.hidden_mamba_act = hidden_mamba_act self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora self.use_shared_attention_lora = use_shared_attention_lora self.lora_rank = lora_rank @@ -262,21 +252,12 @@ def __init__( self.time_step_floor = time_step_floor if use_long_context: self.max_position_embeddings = 16384 - - # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads self.kv_channels = self.hidden_size // self.num_attention_heads self.num_query_groups = self.num_attention_heads - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - - self.use_cache = use_cache - self.num_logits_to_keep = num_logits_to_keep - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) if layers_block_type is None: self.layers_block_type = ( @@ -290,6 +271,14 @@ def __init__( ) else: self.layers_block_type = layers_block_type + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + +class Zamba2RMSNorm(ZambaRMSNorm): + pass def count_mem_blocks_in_config(config: Zamba2Config): @@ -488,10 +477,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") -class Zamba2RMSNorm(ZambaRMSNorm): - pass - - class MambaRMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() @@ -1649,7 +1634,8 @@ def _init_weights(self, module): module.D._no_weight_decay = True dt = torch.exp( - 
torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.n_mamba_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -1783,8 +1769,12 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] if config.use_shared_attention_lora: @@ -1792,12 +1782,24 @@ def __init__(self, config: Zamba2Config): lora_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight') - tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' + str(lora_id) + '.weight') + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + ) + tied_keys_lora.append( + "shared_transf.self_attn.linear_v_lora_B_list." 
+ str(lora_id) + ".weight" + ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index cdfb84b40bb..e3ca547923b 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -80,7 +80,7 @@ def __init__( num_labels=3, num_choices=4, scope=None, - layers_block_type = ['mamba', 'hybrid'], + layers_block_type=["mamba", "hybrid"], num_mem_blocks=1, ): self.parent = parent @@ -137,7 +137,7 @@ def get_config(self): vocab_size=self.vocab_size, hidden_size=self.hidden_size, mamba_dt_rank=self.mamba_dt_rank, - mamba_d_state = self.mamba_d_state, + mamba_d_state=self.mamba_d_state, num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, n_mamba_heads=self.n_mamba_heads, @@ -221,9 +221,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba2 needs the cache to be initialized to return a cache! - past_key_values = Zamba2DynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) + past_key_values = Zamba2DynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, @@ -600,7 +598,9 @@ def test_simple_batched_generate_with_padding(self): self.model.to(torch_device) inputs = self.tokenizer( - ["Hey how are you doing on this lovely evening?", "When did the Roman empire "], padding=True, return_tensors="pt" + ["Hey how are you doing on this lovely evening?", "When did the Roman empire "], + padding=True, + return_tensors="pt", ).to(torch_device) out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10) output_sentences = self.tokenizer.batch_decode(out) @@ -608,14 +608,16 @@ def test_simple_batched_generate_with_padding(self): output_sentences[0], " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for", ) - + self.assertEqual( output_sentences[1], "[PAD][PAD][PAD][PAD] When did the Roman empire 1st fall?\nThe Roman Empire fell in", ) with torch.no_grad(): - logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(dtype=torch.float32) + logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to( + dtype=torch.float32 + ) EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor( [ From 9adf85e00b42a7dca00ff13c8646f504dad6c2c4 Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 11 Nov 2024 06:31:49 +0000 Subject: [PATCH 18/73] Fix modular model converter --- .../models/zamba2/modeling_zamba2.py | 110 ++---------------- 1 file changed, 11 insertions(+), 99 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index e78c155a1f2..79413fdf267 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -56,7 +56,7 @@ if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None @@ -65,10 +65,8 @@ from causal_conv1d import 
causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all( - (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) -) + +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) logger = logging.get_logger(__name__) @@ -271,92 +269,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): - """ - A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache - (which has a constant shape regardless of seq_len). - - This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` - and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, - while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, - and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. - """ - - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype - self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv - self.n_mamba_heads = config.n_mamba_heads - self.conv_states = [] - self.ssm_states = [] - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - 
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") - - def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers @@ -1460,7 +1372,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[Zamba2DynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1471,7 +1383,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1685,14 +1597,14 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. 
- See the `HybridMambaAttentionDynamicCache` class for more details. + See the `Zamba2DynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1831,7 +1743,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[Zamba2DynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1868,7 +1780,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba2 requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -2014,7 +1926,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[Zamba2DynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, From 904da4e957074b55fd07cddb07e4d554fa26178e Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 19 Nov 2024 06:25:44 +0000 Subject: [PATCH 19/73] added inheritances in modular, renamed zamba cache --- .../models/zamba/modeling_zamba.py | 60 +- .../models/zamba2/modular_zamba2.py | 585 +++++++++--------- tests/models/zamba/test_modeling_zamba.py | 4 +- 3 files changed, 328 insertions(+), 321 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 4e97116b563..115e4b19a00 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(DynamicCache): +class ZambaHybridDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
@@ -131,9 +131,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba - intermediate_size = config.mamba_expand * config.hidden_size - ssm_state_size = config.mamba_d_state - conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv self.n_mamba_heads = config.n_mamba_heads self.conv_states = [] self.ssm_states = [] @@ -143,9 +143,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self._buffers = {} for i in range(config.num_hidden_layers): self.conv_states += [ - torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) ] - cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size) + cache_shape = (batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size) self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) @@ -196,12 +196,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") @classmethod # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") class ZambaAttention(nn.Module): @@ -249,7 +249,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -327,7 +327,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -417,7 +417,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -569,7 +569,7 @@ def __init__(self, config: ZambaConfig, layer_idx): ) def cuda_kernels_forward( - self, hidden_states: 
torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None + self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None ): batch_size, seq_len, _ = hidden_states.shape use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1 @@ -665,7 +665,7 @@ def cuda_kernels_forward( contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated linear projection @@ -676,7 +676,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa gate = gate.squeeze(2) gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1) - use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache) + use_cache = isinstance(cache_params, ZambaHybridDynamicCache) # 2. Convolution sequence transformation if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size: if self.training: @@ -758,7 +758,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa ) return contextualized_states - def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None): + def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None): if self.use_fast_kernels: if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type: raise ValueError( @@ -801,7 +801,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -816,7 +816,7 @@ def forward( (see fig. 2 in https://arxiv.org/pdf/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -869,7 +869,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -880,7 +880,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -937,7 +937,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -950,7 +950,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1026,7 +1026,7 @@ class ZambaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_sdpa = False - _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache + _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache _is_stateful = True def _init_weights(self, module): @@ -1120,14 +1120,14 @@ def _check_and_enable_flash_attn_2( config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `HybridMambaAttentionDynamicCache` class for more details. + See the `ZambaHybridDynamicCache` class for more details. 
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1225,7 +1225,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1262,7 +1262,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was " + "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -1409,7 +1409,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1503,7 +1503,7 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache` + # Overwitten -- has a unique cache type, `ZambaHybridDynamicCache` empty_past_kv = past_key_values is None @@ -1517,7 +1517,7 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = HybridMambaAttentionDynamicCache( + past_key_values = ZambaHybridDynamicCache( self.config, input_ids.shape[0], dtype=self.dtype, device=self.device ) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 3b254481a90..df7b3790c66 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,7 +15,7 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -44,9 +44,15 @@ ZambaMambaDecoderLayer, ZambaModel, ZambaRMSNorm, + ZambaMLP, + ZambaAttentionDecoderLayer, + ZambaHybridLayer, + ZambaPreTrainedModel, + ZambaHybridDynamicCache, repeat_kv, ) - +from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum +from ..llama.modeling_llama import rotate_half, apply_rotary_pos_emb if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -305,61 +311,61 @@ def layer_type_list(config: Zamba2Config): return ll -# Helper methods for segment sum computation +# # Helper methods for segment sum computation -def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): - """ - Padding x tensor with `pad_size` on the seq_len dim (dim=1) +# def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): +# """ +# Padding x tensor with `pad_size` on the seq_len dim (dim=1) - Assumes that we only have tensors of either size 4 or 3 - """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) +# Assumes that we only have tensors of either size 4 or 3 +# """ +# pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) - return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) +# return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) -def reshape_into_chunks(input_tensor, pad_size, chunk_size): - """ - Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and - simultaneously splitting it into chunk sequences. +# def reshape_into_chunks(input_tensor, pad_size, chunk_size): +# """ +# Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and +# simultaneously splitting it into chunk sequences. - Assumes that we only have tensors of either size 4 or 3 - """ - # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] - input_tensor = pad_tensor_by_size(input_tensor, pad_size) - - if input_tensor.ndim == 3: - # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] - return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) - else: - # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] - return input_tensor.reshape( - input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] - ) +# Assumes that we only have tensors of either size 4 or 3 +# """ +# # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] +# input_tensor = pad_tensor_by_size(input_tensor, pad_size) +# if input_tensor.ndim == 3: +# # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] +# return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) +# else: +# # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] +# return input_tensor.reshape( +# input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] +# ) -def segment_sum(input_tensor): - """ - More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. - """ - chunk_size = input_tensor.size(-1) - # 1. 
expand input tensor to have an additional dimension and repeat along that dimension - # [..., chunk_size] -> [..., chunk_size, chunk_size] - input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) - # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) - input_tensor = input_tensor.masked_fill(~mask, 0) - # 3. compute actual cumsum - tensor_segsum = torch.cumsum(input_tensor, dim=-2) - - # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) - mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) - tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) - return tensor_segsum - - -class Zamba2DynamicCache(DynamicCache): + +# def segment_sum(input_tensor): +# """ +# More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. +# """ +# chunk_size = input_tensor.size(-1) +# # 1. expand input tensor to have an additional dimension and repeat along that dimension +# # [..., chunk_size] -> [..., chunk_size, chunk_size] +# input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) +# # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag +# mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) +# input_tensor = input_tensor.masked_fill(~mask, 0) +# # 3. compute actual cumsum +# tensor_segsum = torch.cumsum(input_tensor, dim=-2) + +# # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) +# mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) +# tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) +# return tensor_segsum + + +class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
@@ -376,17 +382,20 @@ class Zamba2DynamicCache(DynamicCache): def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): - self.layers_block_type = config.layers_block_type - self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} - - self.has_previous_state = False - self.dtype = dtype - self.conv_kernel_size = config.mamba_d_conv - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - + super().__init__(config, batch_size, dtype, device) + # self.dtype = dtype + # self.layers_block_type = config.layers_block_type + # self.has_previous_state = False + # self.intermediate_size = int(config.mamba_expand * config.hidden_size) + # ssm_state_size = config.mamba_d_state + # conv_kernel_size = config.mamba_d_conv + # self.n_mamba_heads = config.n_mamba_heads + # self.transformer_layers = [] + # self._modules = {} + # self._parameters = {} + # self._buffers = {} + del self.conv_states + del self.ssm_states self.conv_states = { i: torch.zeros( batch_size, @@ -397,20 +406,17 @@ def __init__( ) for i in range(config.num_hidden_layers) } - num_heads = self.intermediate_size // config.mamba_headdim self.ssm_states = { i: torch.zeros( - batch_size, num_heads, config.mamba_headdim, config.mamba_d_state, device=device, dtype=dtype + batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype ) for i in range(config.num_hidden_layers) } - - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + # for i in range(config.num_hidden_layers): + # if self.layers_block_type[i] == "hybrid": + # self.transformer_layers.append(i) + # self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + # self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor @@ -428,71 +434,57 @@ def reset(self): self.conv_states.zero_() self.ssm_states.zero_() - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Update the cache - if self.key_cache[layer_idx].shape[-1] == 0: - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = 
self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_states[layer_idx].device - self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - device = self.ssm_states[layer_idx].device - self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # take any layer that contains cache and not empty tensor - layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") - - @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") - - -class MambaRMSNormGated(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - - if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - return self.weight * hidden_states.to(input_dtype) + # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update + # def update( + # self, + # key_states: torch.Tensor, + # value_states: torch.Tensor, + # layer_idx: int, + # cache_kwargs: Optional[Dict[str, Any]] = None, + # ) -> Tuple[torch.Tensor, torch.Tensor]: + # # Update the cache + # if self.key_cache[layer_idx].shape[-1] == 0: + # self.key_cache[layer_idx] = key_states + # self.value_cache[layer_idx] = value_states + # else: + # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + # return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache + # def reorder_cache(self, beam_idx: torch.LongTensor): + # """Reorders the cache for beam search, given the selected beam indices.""" + # for layer_idx in range(len(self.key_cache)): + # device = self.key_cache[layer_idx].device + # self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + # device = self.value_cache[layer_idx].device + # self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + # device = self.conv_states[layer_idx].device + # self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + # device = self.ssm_states[layer_idx].device + # self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, 
beam_idx.to(device)) + + # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length + # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + # """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # # take any layer that contains cache and not empty tensor + # layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + # if len(self.key_cache) <= layer_idx: + # return 0 + # return self.key_cache[layer_idx].shape[-2] + + + + + ### check the two methods below + # def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + # raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") + + # @classmethod + # def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + # raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") ALL_LAYERNORM_LAYERS.append(Zamba2RMSNorm) @@ -527,42 +519,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class Zamba2Attention(ZambaAttention): """ Multi-headed attention from 'Attention Is All You Need' paper. 
@@ -579,7 +535,7 @@ class Zamba2Attention(ZambaAttention): """ def __init__( - self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks=None, block_id: int = None + self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks: int =None, block_id: int = None ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks @@ -634,7 +590,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -734,7 +690,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -850,7 +806,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -1022,7 +978,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[Zamba2DynamicCache] = None, + cache_params: Optional[Zamba2HybridDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -1171,7 +1127,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -1361,7 +1317,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache] def forward( self, hidden_states, - cache_params: Optional[Zamba2DynamicCache] = None, + cache_params: Optional[Zamba2HybriDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -1375,30 +1331,30 @@ def forward( return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) -class Zamba2MLP(nn.Module): +class Zamba2MLP(ZambaMLP): def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. 
""" - super().__init__() + super().__init__(config) self.config = config - self.hidden_size = config.hidden_size + # self.hidden_size = config.hidden_size + # self.intermediate_size = config.intermediate_size + # self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id - self.ffn_intermediate_size = config.intermediate_size - - self.act_fn = ACT2FN[config.hidden_act] def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) - return self.act_fn(x[0]) * x[1] - self.gated_act_fn = gated_act_fn - self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) - self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) + del self.gate_proj + del self.up_proj + del self.down_proj + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_mlp_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) @@ -1406,7 +1362,7 @@ def gated_act_fn(x): for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.intermediate_size, bias=False) else: gate_up_proj_lora_A = nn.Identity() gate_up_proj_lora_B = nn.Identity() @@ -1433,76 +1389,78 @@ def forward(self, hidden_state, layer_idx=None): return output -class Zamba2AttentionDecoderLayer(nn.Module): +class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): - super().__init__() + super().__init__(config, layer_idx) num_gs = count_mem_blocks_in_config(config) self.block_id = block_id + del self.self_attn + del self.feed_forward self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) - self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) - self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - original_hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` - original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. - This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The - concatenated tensor is then used as input of the pre-attention RMSNorm - (see fig. 2 in https://arxiv.org/pdf/2405.16712). 
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - """ - hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) - hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states, layer_idx) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + # self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + # self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # def forward( + # self, + # hidden_states: torch.Tensor, + # original_hidden_states: torch.Tensor, + # layer_idx: int, + # attention_mask: Optional[torch.Tensor] = None, + # position_ids: Optional[torch.LongTensor] = None, + # past_key_value: Optional[Zamba2HybridDynamicCache] = None, + # output_attentions: Optional[bool] = False, + # use_cache: Optional[bool] = False, + # cache_position: Optional[torch.LongTensor] = None, + # **kwargs, + # ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + # """ + # Args: + # hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + # original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + # This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + # concatenated tensor is then used as input of the pre-attention RMSNorm + # (see fig. 2 in https://arxiv.org/pdf/2405.16712). + # attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + # `(batch, sequence_length)` where padding elements are indicated by 0. + # past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + # output_attentions (`bool`, *optional*): + # Whether or not to return the attentions tensors of all attention layers. See `attentions` under + # returned tensors for more detail. + # use_cache (`bool`, *optional*): + # If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + # (see `past_key_values`). + # cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + # Indices depicting the position of the input sequence tokens in the sequence. 
+ # """ + # hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + # hidden_states = self.input_layernorm(hidden_states) + # hidden_states, self_attn_weights, present_key_value = self.self_attn( + # hidden_states=hidden_states, + # layer_idx=layer_idx, + # attention_mask=attention_mask, + # position_ids=position_ids, + # past_key_value=past_key_value, + # output_attentions=output_attentions, + # use_cache=use_cache, + # cache_position=cache_position, + # **kwargs, + # ) + + # hidden_states = self.pre_ff_layernorm(hidden_states) + # hidden_states = self.feed_forward(hidden_states, layer_idx) + + # outputs = (hidden_states,) + + # if output_attentions: + # outputs += (self_attn_weights,) + + # if use_cache: + # outputs += (present_key_value,) + + # return outputs class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer): @@ -1512,12 +1470,13 @@ def __init__(self, config: Zamba2Config, layer_idx: int): self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) -class Zamba2HybridLayer(nn.Module): - def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): - super().__init__() - self.shared_transf = shared_transf - self.linear = linear - self.mamba_decoder = mamba +class Zamba2HybridLayer(ZambaHybridLayer): + def __init__(self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): + super().__init__(shared_transformer, linear, mamba) + del self.shared_transf + self.shared_transformer = shared_transformer + # self.linear = linear + # self.mamba_decoder = mamba def forward( self, @@ -1527,7 +1486,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1540,7 +1499,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1551,7 +1510,7 @@ def forward( Indices depicting the position of the input sequence tokens in the sequence. 
""" - layer_outputs = self.shared_transf( + layer_outputs = self.shared_transformer( hidden_states, original_hidden_states=original_hidden_states, layer_idx=layer_idx, @@ -1608,16 +1567,47 @@ def forward( "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", ZAMBA2_START_DOCSTRING, ) -class Zamba2PreTrainedModel(PreTrainedModel): +# class Zamba2PreTrainedModel(PreTrainedModel): +# config_class = Zamba2Config +# base_model_prefix = "model" +# supports_gradient_checkpointing = True +# _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] +# _skip_keys_device_placement = "past_key_values" +# _supports_flash_attn_2 = True +# _supports_sdpa = False +# _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache +# _is_stateful = True + +# def _init_weights(self, module): +# std = self.config.initializer_range +# if isinstance(module, (nn.Linear, nn.Conv1d)): +# module.weight.data.normal_(mean=0.0, std=std) +# if module.bias is not None: +# module.bias.data.zero_() +# elif isinstance(module, nn.Embedding): +# module.weight.data.normal_(mean=0.0, std=std) +# if module.padding_idx is not None: +# module.weight.data[module.padding_idx].zero_() +# elif isinstance(module, Zamba2MambaMixer): +# module.A_log._no_weight_decay = True +# module.D._no_weight_decay = True + +# dt = torch.exp( +# torch.rand(self.config.n_mamba_heads) +# * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) +# + math.log(self.config.time_step_min) +# ).clamp(min=self.config.time_step_floor) +# # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 +# inv_dt = dt + torch.log(-torch.expm1(-dt)) + +# with torch.no_grad(): +# module.dt_bias.copy_(inv_dt) +# module.dt_bias._no_reinit = True +class Zamba2PreTrainedModel(ZambaPreTrainedModel): config_class = Zamba2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_sdpa = False - _supports_cache_class = True # Note: only supports Zamba2DynamicCache - _is_stateful = True + _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] def _init_weights(self, module): std = self.config.initializer_range @@ -1645,6 +1635,23 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ): + return super(PreTrainedModel, cls)._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + check_device_map=check_device_map, + hard_check_only=hard_check_only, + ) + ZAMBA2_INPUTS_DOCSTRING = r""" Args: @@ -1681,14 +1688,14 @@ def _init_weights(self, module): config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) - past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`Zamba2HybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A Zamba2HybridDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `Zamba2DynamicCache` class for more details. + See the `Zamba2HybridDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1754,14 +1761,14 @@ def __init__(self, config: Zamba2Config): if config.num_mem_blocks * len(layer_type_list(config)) > 1: prefix_name = f"layers.{layer_id}." tied_keys = [ - "shared_transf.self_attn.q_proj.weight", - "shared_transf.self_attn.k_proj.weight", - "shared_transf.self_attn.v_proj.weight", - "shared_transf.self_attn.o_proj.weight", - "shared_transf.feed_forward.gate_up_proj.weight", - "shared_transf.feed_forward.down_proj.weight", - "shared_transf.input_layernorm.weight", - "shared_transf.pre_ff_layernorm.weight", + "shared_transformer.self_attn.q_proj.weight", + "shared_transformer.self_attn.k_proj.weight", + "shared_transformer.self_attn.v_proj.weight", + "shared_transformer.self_attn.o_proj.weight", + "shared_transformer.feed_forward.gate_up_proj.weight", + "shared_transformer.feed_forward.down_proj.weight", + "shared_transformer.input_layernorm.weight", + "shared_transformer.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] if config.use_shared_mlp_lora: @@ -1770,10 +1777,10 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: tied_keys_lora.append( - "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] @@ -1783,22 +1790,22 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: tied_keys_lora.append( - "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_v_lora_A_list." 
+ str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] @@ -1827,7 +1834,7 @@ def __init__(self, config: Zamba2Config): self.post_init() # Adapted from transformers.models.zamba.modeling_zamba.ZambaForCausalLM.prepare_inputs_for_generation - # with `Zamba2DynamicCache` -> `Zamba2DynamicCache` + # with `Zamba2HybridDynamicCache` -> `Zamba2HybridDynamicCache` def prepare_inputs_for_generation( self, input_ids, @@ -1839,7 +1846,7 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # Overwitten -- has a unique cache type, `Zamba2DynamicCache` + # Overwitten -- has a unique cache type, `Zamba2HybridDynamicCache` empty_past_kv = past_key_values is None @@ -1853,7 +1860,7 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = Zamba2DynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) + past_key_values = Zamba2HybridDynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index a6dd516f98a..20743c796f0 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -46,7 +46,7 @@ ZambaModel, ) from transformers.models.zamba.modeling_zamba import ( - HybridMambaAttentionDynamicCache, + ZambaHybridDynamicCache, ) @@ -215,7 +215,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba needs the cache to be initialized to return a cache! 
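A tiny sketch of what the `_tied_weights_keys` entries built above look like after the rename from `shared_transf` to `shared_transformer`; the layer index and the subset of keys are illustrative only.

```python
layer_id = 6  # hypothetical index of a hybrid layer
tied_keys = [
    "shared_transformer.self_attn.q_proj.weight",
    "shared_transformer.feed_forward.gate_up_proj.weight",
    "shared_transformer.input_layernorm.weight",
]
prefix_name = f"layers.{layer_id}."
tied_weights_keys = [prefix_name + key for key in tied_keys]
print(tied_weights_keys[0])  # layers.6.shared_transformer.self_attn.q_proj.weight
```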
- past_key_values = HybridMambaAttentionDynamicCache( + past_key_values = ZambaHybridDynamicCache( config, input_ids.shape[0], model.dtype, device=model.device ) outputs = model( From 0be27d74b01b7731458e8f5d7a66de2a4a75491d Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 19 Nov 2024 09:22:08 +0000 Subject: [PATCH 20/73] modular rebase --- .../models/zamba2/modular_zamba2.py | 4 +- utils/modular_model_converter.py | 1656 ++++++++++------- 2 files changed, 983 insertions(+), 677 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index df7b3790c66..1d1888e27c3 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -535,7 +535,7 @@ class Zamba2Attention(ZambaAttention): """ def __init__( - self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks: int =None, block_id: int = None + self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks: int = None, block_id: int = None ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks @@ -1317,7 +1317,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic def forward( self, hidden_states, - cache_params: Optional[Zamba2HybriDynamicCache] = None, + cache_params: Optional[Zamba2HybridDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index bda143c2577..b1dfa18a7a9 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -17,13 +17,14 @@ import importlib import os import re +from abc import ABC, abstractmethod from collections import defaultdict, deque -from typing import Dict, List, Optional, Set +from typing import Dict, Set import libcst as cst from check_copies import run_ruff from create_dependency_mapping import find_priority_list -from libcst import ClassDef, CSTTransformer, CSTVisitor +from libcst import ClassDef, CSTVisitor from libcst import matchers as m from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider @@ -34,13 +35,6 @@ logger = logging.get_logger(__name__) -# This is used to avoid overwriting these top-level assignments even if they are in the dependency graph. Otherwise, the -# value from the dependency is used, then mapped to current name convention, resulting in wrong value. -# The corresponding mapped value is used to define the file target for the assignment -ASSIGNMENTS_TO_KEEP = { - "_CHECKPOINT_FOR_DOC": "modeling", -} - AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from {relative_path}. # Do NOT edit this file manually as any edits will be overwritten by the generation of @@ -61,137 +55,23 @@ def get_module_source_from_name(module_name: str) -> str: return source_code -class ClassFinder(CSTVisitor): - """A visitor class which analyses a module, creating a mapping of dependencies between classes and functions. 
- For example if the visited code has - ```python3 - def init_value(): return 1 - - class LlamaModel(PreTrainedModel): - def __init__(self): - super().__init__(self) - self.value = init_value() - ``` - then the `class_dependency_mapping` should be: `{"LlamaModel":["PreTrainedModel","init_value"], "init_value":[]} - - The dependency mapping is updated via the `visit_Name`, `visit_Arg` and `visit_Decorator`. This is very broad, and by - checking the parent node, or the scope of a `cst.Name` or `cst.Arg` or `cst.Decorator` we are able to map the - dependence parent -> child. - - When visiting such nodes, we update the dependency of the parent node, to take into account the visited node. - - All `visit_XXX` correspond to the code executed when vising the cst.Node of type XXX. - """ - - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) +def preserve_case_replace(text, patterns: dict, default_name: str): + # Create a regex pattern to match all variations + regex_pattern = "|".join(re.escape(key) for key in patterns.keys()) + compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - def __init__(self, python_module: cst.Module): - # fmt: off - self.python_module: cst.Module = python_module # original cst.Module being visited - self.classes: Dict[str, cst.ClassDef] = {} # stores a mapping from classname to the cst.Node - self.imports = {} # stores all import statements - self.function_def = {} # stores global scope function definition - self.assignments = {} # LLAMA_DOCSTRING - self.class_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] - self.first_lvl_dependency_mapping = {} # "LlamaModel":["LlamaDecoderLayer, "LlamaRMSNorm", "LlamaPreTrainedModel"], "LlamaDecoderLayer":["LlamaAttention","Llama"] - # fmt: on - - def _update_class_dependency(self, name, value): - """Update the dependency mapping for `name` with `value` by appending the previous - dependencies to the new `value`. - """ - dep = set(self.first_lvl_dependency_mapping.get(name, set())) | set({value}) - self.first_lvl_dependency_mapping[name] = dep - - dep = set(self.class_dependency_mapping.get(value, set())) - dep |= set(self.class_dependency_mapping.get(name, {})) | set({value}) - self.class_dependency_mapping[name] = dep - - def visit_ClassDef(self, node: ClassDef) -> None: - """We don't have non global scope class defs in transformers. Here we add the inheritance dependencies""" - self.classes[node.name.value] = node - for k in node.bases: # deal with inheritance - base_name = self.python_module.code_for_node(k) - self._update_class_dependency(node.name.value, base_name) + def replace(match): + word = match.group(0) + result = patterns.get(word, default_name) + return result - def visit_SimpleStatementLine(self, node): - """ - Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT' and all import statements - are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. 
- """ - if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])) and m.matches( - self.get_metadata(cst.metadata.ParentNodeProvider, node), m.Module() - ): - left_hand_side = node.body[0].targets[0].target - if hasattr(left_hand_side, "value"): - if left_hand_side.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[left_hand_side.value] = node - else: - for idx, target in enumerate(list(left_hand_side.elements)): - if target.value.value not in ASSIGNMENTS_TO_KEEP.keys(): - self.assignments[target.value.value] = node.body[0].value.elements[idx].value - if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): - self.imports[node.body[0].names] = node + return compiled_regex.sub(replace, text) - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.function_def[node.name.value] = node - def leave_If(self, node): - for stmt in node.body.body: - if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): - self.imports[stmt.body[0].names] = node - - def leave_Name(self, node): - if node.value in self.classes.keys() | self.assignments.keys() | self.function_def.keys(): - parent = self.get_metadata(cst.metadata.ScopeProvider, node) - if not isinstance(parent, cst.metadata.scope_provider.GlobalScope): - self._update_class_dependency(parent._name_prefix.split(".")[0], node.value) - - def leave_Arg(self, node): - if m.matches(node.value, m.Name()): - parent = self.get_metadata(ParentNodeProvider, node) - if m.matches(parent, m.ClassDef()) and parent.bases: - self._update_class_dependency(parent.name.value, node.value.value) - - def leave_Dict(self, node): - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent, m.Assign(targets=[m.AssignTarget()])): - name = parent.targets[0].target.value - if name in self.assignments: - for k in node.elements: - dep_name = k.value.value - if dep_name in self.classes: - self._update_class_dependency(name, dep_name) - - def leave_Decorator(self, node): - if hasattr(node.decorator, "args"): - for k in node.decorator.args: - if m.matches(k.value, m.Call(func=m.Attribute(value=m.Name()))): # and k.value.func.value.value: - if k.value.func.value.value not in self.assignments: - raise ValueError( - f"We detected a call to {k.value.func.value.value}, but it was not assigned. See the list of assigments {self.assignments.keys()}" - ) - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.func.value.value) - elif m.matches(k, m.Arg(value=m.Name())) and k.value.value in self.assignments: - parent = self.get_metadata(cst.metadata.ParentNodeProvider, node) - scope = self.get_metadata(cst.metadata.ScopeProvider, node) - name = scope._name_prefix.split(".")[0] if scope._name_prefix != "" else parent.name.value - self._update_class_dependency(name, k.value.value) - - def leave_Module(self, node): - """When leaving the module, we store the position of each global scoped node (Assigns, function def and class def) - to allow sorting the dependencies based on their position in the code. We use the PositionProvider metadata wrapper for this. 
- """ - self.global_nodes = {**self.assignments, **self.classes, **self.function_def} - # now sort the class dependency_mapping based on the position of the nodes - self.class_start_line = {} - for id, node in self.global_nodes.items(): - self.class_start_line[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line +def convert_to_camelcase(text, old_name: str, default_old_name: str): + # Regex pattern to match consecutive uppercase letters and lowercase the first set + result = re.sub(rf"^({old_name})(?=[a-z]+)", lambda m: default_old_name, text, flags=re.IGNORECASE, count=1) + return result class ReplaceNameTransformer(m.MatcherDecoratableTransformer): @@ -210,8 +90,6 @@ def __init__( new_name, given_old_name=None, given_new_name=None, - old_class_name: str = None, - new_class_name: str = None, ): super().__init__() self.old_name = old_name @@ -232,70 +110,17 @@ def __init__( self.default_old_name = CONFIG_MAPPING_NAMES[self.old_name].replace("Config", "") if self.default_old_name.isupper(): self.default_old_name = self.default_old_name.capitalize() - if new_class_name is not None and old_class_name is not None and old_class_name not in self.patterns: - # In last recourse, when the suffix of the new class is not the same as the old class, - # and if the old and new classes start with the default name, we keep the default class name - # and replace the old suffix with the new one. - # Useful when we have a class like `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration` - # where a model extends another model, but is used for a different task. - if old_class_name.startswith(self.default_old_name) and new_class_name.startswith(self.default_name): - self.patterns[old_class_name[len(self.default_old_name) :]] = new_class_name[len(self.default_name) :] - - def preserve_case_replace(self, text): - # Create a regex pattern to match all variations - regex_pattern = "|".join(re.escape(key) for key in self.patterns.keys()) - compiled_regex = re.compile(regex_pattern, re.IGNORECASE) - - def replace(match): - word = match.group(0) - result = self.patterns.get(word, self.default_name) - return result - - return compiled_regex.sub(replace, text) - - def convert_to_camelcase(self, text): - # Regex pattern to match consecutive uppercase letters and lowercase the first set - result = re.sub( - rf"^({self.old_name})(?=[a-z]+)", lambda m: self.default_old_name, text, flags=re.IGNORECASE, count=1 - ) - return result @m.leave(m.Name() | m.SimpleString() | m.Comment()) def replace_name(self, original_node, updated_node): if re.findall(r"# Copied from", updated_node.value): return cst.RemoveFromParent() - update = self.preserve_case_replace(updated_node.value) + update = preserve_case_replace(updated_node.value, self.patterns, self.default_name) return updated_node.with_changes(value=update) def leave_ClassDef(self, original_node, updated_node): - return updated_node.with_changes(name=cst.Name(self.convert_to_camelcase(updated_node.name.value))) - - -def find_classes_in_file( - module: cst.Module, - old_id="llama", - new_id="gemma", - given_old_name=None, - given_new_name=None, - old_class_name=None, - new_class_name=None, -): - """Helper function to rename and then parse a source file using the ClassFinder""" - transformer = ReplaceNameTransformer( - old_id, - new_id, - given_old_name=given_old_name, - given_new_name=given_new_name, - old_class_name=old_class_name, - new_class_name=new_class_name, - ) - new_module = module.visit(transformer) - - wrapper = 
MetadataWrapper(new_module) - - class_finder = ClassFinder(new_module) - wrapper.visit(class_finder) - return class_finder + new_name = convert_to_camelcase(updated_node.name.value, self.old_name, self.default_old_name) + return updated_node.with_changes(name=cst.Name(new_name)) DOCSTRING_NODE = m.SimpleStatementLine( @@ -412,13 +237,12 @@ def merge_docstrings(original_docstring, updated_docstring): class SuperTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (ParentNodeProvider,) - def __init__(self, python_module: cst.Module, original_methods, updated_methods, class_name="", all_bases=None): + def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None): self.python_module = python_module self.original_methods = original_methods self.updated_methods = updated_methods self.all_assign_target = {} self.deleted_targets = {} # child node can delete some arguments - self.class_name = class_name self.all_bases = all_bases or [] self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) @@ -437,7 +261,6 @@ def update_body(self, existing_body, new_statements): if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): target = self.python_module.code_for_node(node.body[0].target) self.deleted_targets[target] = node - continue for stmt in existing_body: if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): @@ -447,6 +270,9 @@ def update_body(self, existing_body, new_statements): continue if target in self.all_assign_target: stmt = self.all_assign_target[target] + # Skip the docstring (will be added later on, at the beginning) + elif m.matches(stmt, DOCSTRING_NODE): + continue comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() deduplicated_new_body.append(stmt) @@ -456,17 +282,47 @@ def update_body(self, existing_body, new_statements): code = self.python_module.code_for_node(node) comment_less_code = re.sub(r"#.*", "", code).strip() comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() - if ( - node not in deduplicated_new_body - and "super().__init__" not in comment_less_code - and comment_less_code not in existing_nodes - ): + if node not in deduplicated_new_body and comment_less_code not in existing_nodes: if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): - # HACK here to fix the pos_init() that has to be last we kinda do this. 
- deduplicated_new_body = deduplicated_new_body[:-1] + [node] + deduplicated_new_body[-1:] + deduplicated_new_body.append(node) existing_nodes.add(comment_less_code) + + deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) + return deduplicated_new_body + def _fix_post_init_location(self, new_body: list[cst.CSTNode]): + """Fix the location of the `post_init()` in the new body, if we added statements after the call to + `super()` (it needs to be the very last statement called)""" + # Fix the post_init() that has to be last + for i, node in enumerate(new_body): + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "self.post_init(" in comment_less_code and i < len(new_body) - 1: + # Remove it and add it again at the end + new_body.pop(i) + new_body.append(node) + break + return new_body + + def _fix_init_location(self, new_body): + """Fix the location of the `super().__init__()` in the new body, if we had new statements before it.""" + start_index = 0 + for i, node in enumerate(new_body): + if m.matches(node, DOCSTRING_NODE) and i == start_index: + start_index += 1 + continue + code = self.python_module.code_for_node(node) + comment_less_code = re.sub(r"#.*", "", code).strip() + comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() + if "super().__init__" in comment_less_code and i > start_index: + # Remove it and add it again at the top after the docstrings + node = new_body.pop(i) + new_body = new_body[:start_index] + [node] + new_body[start_index:] + break + return new_body + def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: """Updates the body of the input `node`'s `func_name` function by replacing calls to super().func_name() with the source code of the parent class' `func_name`. @@ -479,10 +335,11 @@ def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CS new_body = [] has_super_call = False - for expr in node.body: + for i, expr in enumerate(node.body): if is_call_to_super(expr, func_name): has_super_call = True - new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body)) + new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) + new_body = self._fix_init_location(new_body) else: expr = expr.visit(self.transformer) if m.matches(expr, DOCSTRING_NODE): @@ -524,11 +381,463 @@ def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> c return updated_node -def replace_call_to_super( - class_finder: ClassFinder, updated_node: cst.ClassDef, class_name: str, all_bases: List[str] -): +def find_all_dependencies( + dependency_mapping: Dict[str, set], + start_entity: str | None = None, + initial_dependencies: set | None = None, + initial_checked_dependencies: set | None = None, + return_parent: bool = False, +) -> list | set: + """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of + BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. + + Args: + dependency_mapping (`Dict[str, set]`): + A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, + a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called + in `foo`'s definition. 
+ start_entity (str | None, *optional*): + A key of `dependency_mapping`, indicating from which entity to start the search. + initial_dependencies (set | None, *optional*): + If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue + from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. + initial_checked_dependencies (set | None, *optional*): + If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. + return_parent (bool, *optional*): + If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note + that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before childs. + Returns: + A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. + + Example: + Given the following structure in the `modular_xxx.py` file: + ``` + def foo1(): + pass + + def foo2(): + pass + + def bar(): + foo1() + + def foobar(): + bar() + foo2() + + class MyLayer(SomeOtherModelLayer): + def forward(...): + foobar() + ``` + and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: + ``` + dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} + find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) + >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] + ``` + That is, all the functions needed (and potentially their immediate parent) so that the function to be added + in MyLayer (`foobar`) can work correctly. + """ + if initial_dependencies is None and start_entity is not None: + initial_dependencies = dependency_mapping[start_entity] + if initial_checked_dependencies is None: + initial_checked_dependencies = set() + + dependency_queue = deque(initial_dependencies) + all_dependencies = set() + all_dependencies_with_parent = [] + checked_dependencies = set(initial_checked_dependencies) + parents = {initial_dep: start_entity for initial_dep in initial_dependencies} + while len(dependency_queue) > 0: + # Pick element to visit + current = dependency_queue.popleft() + if current not in checked_dependencies: + # Add the dependencies + all_dependencies.add(current) + all_dependencies_with_parent += [(current, parents[current])] + if current in dependency_mapping.keys(): + # Update dependency queue + dependency_queue.extend(dependency_mapping[current]) + parents.update({dep: current for dep in dependency_mapping[current]}) + # add visited node to the list + checked_dependencies.add(current) + + if not return_parent: + return all_dependencies + # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) + return all_dependencies_with_parent + + +# These top-level variables will always use the value in the `modular_xxx.py` file +ASSIGNMENTS_TO_KEEP = { + "_CHECKPOINT_FOR_DOC", +} + + +class ClassDependencyMapper(CSTVisitor): + """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of + `global_names`. 
+ """ + + def __init__( + self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None + ): + super().__init__() + self.class_name = class_name + self.dependencies = set() + self.global_names = global_names + self.objects_imported_from_modeling = ( + set() if objects_imported_from_modeling is None else objects_imported_from_modeling + ) + + def visit_Name(self, node): + if ( + node.value != self.class_name + and node.value in self.global_names + and node.value not in self.objects_imported_from_modeling + ): + self.dependencies.add(node.value) + + +def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: + """Create immediate dependencies for a class node based on the `global_names`.""" + temp_module = cst.Module(body=[node]) + visitor = ClassDependencyMapper(node.name.value, global_names) + temp_module.visit(visitor) + return visitor.dependencies + + +def augmented_dependencies_for_class_node( + node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None +) -> set: + """Create augmented dependencies for a class node based on a `mapper`. + Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. + """ + temp_module = cst.Module(body=[node]) + visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling) + temp_module.visit(visitor) + return mapper.augment_dependencies(visitor.dependencies) + + +# All the potential file types to create +ALL_FILE_TYPES = ( + "modeling", + "configuration", + "tokenization", + "processing", + "image_processing", + "feature_extractor", +) + + +class ModuleMapper(CSTVisitor, ABC): + """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. + Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in + `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`). + It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the + modeling files that will be visited. """ - Given the `class_name`, the `updated_node`'s call to super are unpacked. + + METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) + + def __init__(self, python_module: cst.Module): + # fmt: off + self.python_module: cst.Module = python_module # original cst.Module being visited + self.classes: Dict[str, cst.ClassDef] = {} # mapping from class names to Nodes (it will be ordered by default!!) + self.imports = [] # stores all import statements + self.functions: Dict[str, cst.FunctionDef] = {} # mapping of global scope function names to Nodes + self.object_dependency_mapping = defaultdict(set) # immediate function/assignment dependency mapping (i.e. 
dependencies immediately in the function/assignment definition) + self.assignments: Dict[str, cst.SimpleStatementLine] = {} # mapping of global assignments names to Nodes + self.current_function = None # this keeps track of the current module-scope function + self.current_assignment = None # this keeps track of the current module-scope assignment + # this keeps track of objects imported from modeling files (`from .configuration import Config`) -> `Config` should not be a dependency + self.objects_imported_from_modeling = set() + # regex pattern joining every possible file type + self.match_patterns = "|".join(ALL_FILE_TYPES) + # fmt: on + + def visit_ImportFrom(self, node): + """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have + `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs + to be added (because it will be part of the imports)""" + import_module = self.python_module.code_for_node(node.module) + import_statement = "." * len(node.relative) + import_module + if re.search(rf"^\.({self.match_patterns})_.*", import_statement): + for imported_object in node.names: + # If an alias is present, we record it and not the original name + if imported_object.evaluated_alias is not None: + self.objects_imported_from_modeling.add(imported_object.evaluated_alias) + else: + self.objects_imported_from_modeling.add(imported_object.evaluated_name) + + def visit_SimpleStatementLine(self, node): + """ + Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements + are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. + """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) + if m.matches(parent_node, m.Module()): + if m.matches(node, simple_top_level_assign_structure): + left_hand_side = node.body[0].targets[0].target.value + self.current_assignment = left_hand_side + self.assignments[left_hand_side] = node + elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + self.imports.append(node) + + def leave_SimpleStatementLine(self, node): + # No need to check for the parent here -> everytime we exit one, it should be None anyway independently of where the + # SimpleStatement is located + self.current_assignment = None + + def visit_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = node.name.value + self.functions[node.name.value] = node + + def leave_FunctionDef(self, node): + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + if m.matches(parent_node, m.Module()): + self.current_function = None + + def visit_If(self, node): + for stmt in node.body.body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): + self.imports.append(node) + + def visit_ClassDef(self, node: ClassDef) -> None: + """Record class nodes to create their dependencies at the end.""" + self.classes[node.name.value] = node + + def visit_Name(self, node: cst.Call): + """This is used to create a mapping from module-scope functions and assignments to objects used inside them.""" + if self.current_function is not None: + self.object_dependency_mapping[self.current_function].add(node.value) + if self.current_assignment 
is not None: + self.object_dependency_mapping[self.current_assignment].add(node.value) + + def leave_Module(self, node): + """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies + based on their position in the code later. We use the PositionProvider metadata wrapper for this. + We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in + `self.global_nodes`. + """ + # assign all nodes + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + # now sort the class dependency_mapping based on the position of the nodes + self.start_lines = {} + for id, node in self.global_nodes.items(): + self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line + + # Since we added every Name as part of `self.object_dependency_mapping`, we now remove those that + # are not part of the recorded objects (i.e. built-in variables, imports, etc) + global_objects = set(self.global_nodes.keys()) + for object_name, dependencies in self.object_dependency_mapping.items(): + self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} + + def _compute_recursive_object_dependencies(self) -> dict[str, set]: + """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the + following file: + ``` + def foo(): + pass + + def bar(): + foo() + + def test(): + bar() + ``` + this visitor can only record immediate dependencies, i.e. it will record the following + `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create + the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`. + """ + recursive_dependencies = {} + for object_name in self.object_dependency_mapping.keys(): + all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name) + recursive_dependencies[object_name] = all_dependencies + return recursive_dependencies + + def augment_dependencies(self, dependencies: set[str]) -> set[str]: + """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and + **assignments** present in the `dependencies`. + """ + new_dependencies = dependencies.copy() + # Go through the set of dependencies + for dep in tuple(dependencies): + if dep in self.object_recursive_dependency_mapping.keys(): + new_dependencies.update(self.object_recursive_dependency_mapping[dep]) + return new_dependencies + + def compute_class_dependencies(self): + """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies.""" + self.class_dependency_mapping = {} + for class_name, class_node in self.classes.items(): + dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) + # Correctly augment class dependencies with all needed objects + self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies) + + @abstractmethod + def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: + raise NotImplementedError + + +class ModelFileMapper(ModuleMapper): + """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file + in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. 
+ For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes + care of correctly merging dependencies, then finalizes all dependency graph computations. + Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. + For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies + of the modeling files as well. + """ + + def __init__(self, python_module: cst.Module): + super().__init__(python_module) + + def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]: + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. + """ + relative_order = {} + idx = 0 + classes = sorted( + [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x] + ) + # This is because for merged dependencies, we only have relative order in the other visited file, so we need + # to track dependency order relative to a given class + if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): + raise ValueError("Cannot correctly find the relative order of the dependencies.") + + remaining_dependencies = missing_dependencies.copy() + + # Start by tracking relative order class by class + for class_name in classes: + class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + # We need to differentiate between nodes that were already present (we can get relative order globally) and + # nodes that were merged (we can get relative order only relative to the class the dependencies relate to) + for class_dep in class_dependencies: + if class_dep in self.start_lines: + original_dependencies.append(class_dep) + else: + merged_dependencies.append(class_dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + remaining_dependencies.remove(dep) + relative_order[dep] = idx + idx += 1 + # Add the class itself + remaining_dependencies.remove(class_name) + relative_order[class_name] = idx + idx += 1 + + # Now add what still remains + remaining_dependencies = tuple(remaining_dependencies) + original_dependencies = [] + merged_dependencies = [] + for dep in remaining_dependencies: + if dep in self.modular_file_start_lines: + merged_dependencies.append(dep) + else: + original_dependencies.append(dep) + # Sort both list according to the order in their respective file + original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x]) + + # Add all original node first, then merged ones + for dep in original_dependencies + merged_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): + """Update the global nodes and function dependency mapping with those from the modular file. 
+ + Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies + instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). + """ + # Add/overwrite all needed function nodes and dependencies + self.functions.update(functions) + self.object_dependency_mapping.update( + {obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} + ) + + def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): + """Update the global nodes with the assignment from the modular file. + + Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it is + in `ASSIGNMENTS_TO_KEEP`. Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the + big docstrings. + """ + for assignment, node in assignments.items(): + if assignment in ASSIGNMENTS_TO_KEEP or assignment not in self.assignments: + self.assignments[assignment] = node + if assignment in object_mapping: + self.object_dependency_mapping[assignment] = object_mapping[assignment] + + def _merge_classes(self, classes: dict[str, cst.CSTNode]): + """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and + are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined + classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we + do not add the new classes to `self.classes`, but only to `global_nodes`. + """ + # Add/overwrite all needed function nodes and dependencies + self.global_nodes.update( + { + name: node + for name, node in classes.items() + if name not in self.classes and name not in self.objects_imported_from_modeling + } + ) + + def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): + """Merge classes, functions and assignments from the modular definitions into the current module file, + then record the relative order of all nodes. + Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the + merge with other files dependencies. + """ + self._merge_functions(functions, object_mapping) + self._merge_assignments(assignments, object_mapping) + self._merge_classes(classes) + self.modular_file_start_lines = start_lines + + # Correctly re-set the global nodes at this point + self.global_nodes.update(self.functions) + self.global_nodes.update(self.assignments) + # Create the global mapping of recursive dependencies for functions and assignments + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + + @classmethod + def visit_and_merge_dependencies( + cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines + ) -> "ModelFileMapper": + wrapper = MetadataWrapper(module) + mapper = cls(module) + wrapper.visit(mapper) + # Merge dependencies + mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) + # Create the class dependencies graph + mapper.compute_class_dependencies() + return mapper + + +def replace_class_node(mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str): + """ + Replace a class node which inherits from another modeling class. 
This function works in the following way: + - start from the base class node of the inherited class (a cst.Node) + - replace all methods of the base node with the methods defined in the child class + - append all new methods defined in the child class + - replace all calls to super() with the unravelled code | ```python | | ```python | class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): @@ -547,14 +856,15 @@ def replace_call_to_super( | self.post_init() | ``` """ - original_node = class_finder.classes[class_name] + all_bases = [k.value.value for k in class_node.bases] + + original_node = mapper.classes[renamed_super_class] original_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in original_node.body.body } updated_methods = { - f.name.value if hasattr(f, "name") else class_finder.python_module.code_for_node(f): f - for f in updated_node.body.body + f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body } end_meth = [] @@ -562,7 +872,7 @@ def replace_call_to_super( docstring_node = [] # Iterate directly from node.body as there can be property/setters with same names which are overwritten when we use a dict for func in original_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None: new_params = updated_methods[name].params # Replace the method in the replacement class, preserving decorators @@ -573,19 +883,23 @@ def replace_call_to_super( new_params = new_params.with_changes( params=list(parent_params.values()), star_kwarg=func.params.star_kwarg ) + # Keep decorators in `modular_xxx.py` if any, else original decorators + new_decorators = ( + updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators + ) if not re.match( r"\ndef .*\(.*\):\n raise.*Error\(.*", - class_finder.python_module.code_for_node(updated_methods[name]), + mapper.python_module.code_for_node(updated_methods[name]), ): - func = func.with_changes(body=updated_methods[name].body, params=new_params) + func = func.with_changes(body=updated_methods[name].body, params=new_params, decorators=new_decorators) else: continue if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func elif m.matches(func, DOCSTRING_NODE): docstring_node = [func] @@ -593,8 +907,8 @@ def replace_call_to_super( end_meth.append(func) # Port new methods that are defined only in modular-file and append at the end - for func in updated_node.body.body: - name = func.name.value if hasattr(func, "name") else class_finder.python_module.code_for_node(func) + for func in class_node.body.body: + name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) if m.matches(func, DOCSTRING_NODE): # This processes the docstring 
of the class! # Extract the original docstring updated_docstring = func.body[0].value.value @@ -608,22 +922,28 @@ def replace_call_to_super( end_meth.append(func) if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): # TODO we only use single assign might cause issues - target = class_finder.python_module.code_for_node(func.body[0].targets[0]) + target = mapper.python_module.code_for_node(func.body[0].targets[0]) assign_targets[target] = func if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): - target = class_finder.python_module.code_for_node(func.body[0].target) + target = mapper.python_module.code_for_node(func.body[0].target) assign_targets[target] = func end_meth = docstring_node + list(assign_targets.values()) + end_meth + # Replace the calls to `super()` with the unrolled code result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) temp_module = cst.Module(body=[result_node]) new_module = MetadataWrapper(temp_module) new_replacement_class = new_module.visit( - SuperTransformer(temp_module, original_methods, updated_methods, class_name, all_bases) + SuperTransformer(temp_module, original_methods, updated_methods, all_bases) ) new_replacement_body = new_replacement_class.body[0].body # get the indented block - return original_node.with_changes(body=new_replacement_body) + # Use decorators redefined in `modular_xxx.py` if any + new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators + # Always use the new name of the class (in case we use e.g. `ColPaliForRetrieval` inheriting from `PaliGemmaForConditionalGeneration`) + name = class_node.name + + return original_node.with_changes(body=new_replacement_body, decorators=new_decorators, name=name) TYPE_TO_FILE_TYPE = { @@ -632,498 +952,483 @@ def replace_call_to_super( "Processor": "processing", "ImageProcessor": "image_processing", "FeatureExtractor": "feature_extractor", + "ProcessorKwargs": "processing", + "ImagesKwargs": "processing", + "TextKwargs": "processing", } -def get_new_part(class_name, base_class): +def find_file_type(class_name: str) -> str: + """Based on a class name, find the file type corresponding to the class. + If the class name is `LlamaConfig` it will return `configuration`. + The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` """ - When `MyClassNameAttention` inherits from `MistralAttention`, we need - to process the name to properly find dependencies. - - Here we take what is the same (Attention) and what is different - when finding the dependencies. - """ - common_suffix_len = 0 - for i in range(1, min(len(class_name), len(base_class)) + 1): - if class_name[-i] == base_class[-i]: - common_suffix_len += 1 - else: - break - - if common_suffix_len > 0: - new_part = class_name[:-common_suffix_len] + match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) + match = re.search(rf"({match_pattern})$", class_name) + if match: + file_type = TYPE_TO_FILE_TYPE[match.group(1)] else: - new_part = class_name + file_type = "modeling" + return file_type - # Convert the remaining new part to snake_case - snake_case = re.sub(r"(? 
0: + new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) + imports_to_keep.append(new_node) - def foobar(): - bar() - foo2() - class MyLayer(SomeOtherModelLayer): - def forward(...): - foobar() - ``` - and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: - ``` - dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} - find_all_dependencies('foobar', dependency_mapping) - >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] - ``` - That is, all the functions needed (and their immediate parent) so that the function to be added in MyLayer (`foobar`) can - work correctly. +def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: + """Get all the imports needed in the `body`, from the list of `all_imports`. + `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`. + Note: we need to use `isinstance` on scope assignements, m.matches apparently does not work here yet! """ - all_dependencies = deque(dependency_mapping[function]) - all_dependencies_with_parent = [(dep, function) for dep in dependency_mapping[function]] - checked_dependencies = set(function) - while len(all_dependencies) > 0: - # Pick element to visit - parent = all_dependencies.popleft() - if parent not in checked_dependencies: - # Update dependencies - all_dependencies.extend(dependency_mapping[parent]) - all_dependencies_with_parent += [(dependency, parent) for dependency in dependency_mapping[parent]] - # add visited node to the list - checked_dependencies.add(parent) - - # no child can ever appear before its parent thanks to the queue (needed to add them at the correct location in the body later) - return all_dependencies_with_parent - - -class PostModularConverterCleaner(CSTTransformer): - """Allow simple cleaning after conversion. Remove top-level functions/classes without any calls (they may arise due - to dependency mapping, even if code parts with those functions/classes were overwritten)""" - - METADATA_DEPENDENCIES = (ParentNodeProvider,) - - def __init__(self, added_dependencies: set): - super().__init__() - self.top_level_functions_or_classes = {} - self.all_used_functions_or_classes = set() - self.added_dependencies = added_dependencies - - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_ClassDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.top_level_functions_or_classes[node.name.value] = node - - def visit_Name(self, node: cst.Name): - """This is used to find any mention of a top-level function or class except its own definition. - It will contain other names as well, but those will not be used. This is the most general way to do it - since mentions may appear in a lot of different contexts (apart from simple Call to the function/class). - e.g. Attention classes are only mentionned by their name in a dict assignment. 
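The unused-import detection that `get_needed_imports` relies on can be reproduced in isolation with libcst's scope metadata. The sketch below assumes `libcst` is installed and uses a made-up `code` snippet; it only mirrors the `ScopeProvider` / `assignment.references` pattern used above, without the redefinition and protected-import handling.

```python
import libcst as cst
from libcst.metadata import Assignment, MetadataWrapper, ScopeProvider

# Made-up module: `os` is imported but never referenced, `re` is used inside a function
code = """
import os
import re

def has_known_suffix(name):
    return re.search(r"(Config|Model)$", name) is not None
"""

unused = set()
for scope in set(MetadataWrapper(cst.parse_module(code)).resolve(ScopeProvider).values()):
    for assignment in scope.assignments:
        if isinstance(assignment, Assignment) and isinstance(assignment.node, (cst.Import, cst.ImportFrom)):
            if len(assignment.references) == 0:   # no access ever resolves to this import binding
                unused.add(assignment.name)

print(unused)  # {'os'}
```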
- """ - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - - if not ( - (m.matches(parent_node, m.ClassDef()) and parent_node.name.value == node.value) - or (m.matches(parent_node, m.FunctionDef()) and parent_node.name.value == node.value) - ): - self.all_used_functions_or_classes.add(node.value) - - def leave_Module(self, original_node: cst.Module, node): - # Find any class/function that was mistakenly added as part of the dependencies and remove it - unused = self.added_dependencies - self.all_used_functions_or_classes - nodes_to_remove = [ - self.top_level_functions_or_classes[name] for name in unused if name in self.top_level_functions_or_classes - ] - new_body = [node_ for node_ in original_node.body if node_ not in nodes_to_remove] - # Return a new module with the updated body - return node.with_changes(body=new_body) + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) + scopes = set(wrapper.resolve(ScopeProvider).values()) + unused_imports = set() + import_ref_count = {} + for scope in scopes: + for assignment in scope.assignments: + node = assignment.node + if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): + ref_count = len(assignment.references) + name = assignment.name + # Similar imports may be redefined, and only used between their 1st and 2nd definition + # so if we already have a ref count > 0, the imports is actually used + if (ref_count == 0 and import_ref_count.get(name, -1) <= 0) or name in body.keys(): + unused_imports.add(name) + import_ref_count[name] = ref_count + + imports_to_keep = [] + for node in all_imports: + if m.matches(node, m.If()): # handle safe imports + new_statements = [] + for stmt_node in node.body.body: + append_new_import_node(stmt_node, unused_imports, new_statements) + if len(new_statements) > 0: + new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) + imports_to_keep.append(new_node) + else: + append_new_import_node(node, unused_imports, imports_to_keep) + + protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] + usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] + # If the same import is both protected and unprotected, only keep the protected one + for protected_node in protected_import_nodes: + for stmt_node in protected_node.body.body: + usual_import_nodes = [node for node in usual_import_nodes if node.body[0] != stmt_node.body[0]] + + # Protected imports always appear at the end of all imports + return usual_import_nodes + protected_import_nodes + + +def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: + """Split the `__all__` assignment found in the modular between each corresponding files.""" + all_all_per_file = {} + assign_node = node.body[0] + if isinstance(assign_node.value, cst.List): + # Extract the elements from the list + all_all_to_add = defaultdict(list) + for element in assign_node.value.elements: + if isinstance(element.value, cst.SimpleString): + # Remove quotes and add the string to the elements list + class_name = element.value.value + file = find_file_type(element.value.evaluated_value) + all_all_to_add[file] += [class_name] + for file, new_alls in all_all_to_add.items(): + new_node = assign_node.with_changes( + value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) + ) + all_all_per_file[file] = 
node.with_changes(body=[new_node]) + return all_all_per_file -class ModularConverterTransformer(CSTTransformer): - METADATA_DEPENDENCIES = (ParentNodeProvider, ScopeProvider, PositionProvider) +class ModularFileMapper(ModuleMapper): + """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, + then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. + Calling the method `create_modules()` after visit will create all modules based on this modular file. + """ def __init__(self, python_module, new_name, given_old_name=None, given_new_name=None): - super().__init__() - self.model_name = ( - new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` our `phi3` - ) + super().__init__(python_module) + # fmt: off + self.model_name = new_name # name of the model being defined. Should be in the format of `llama` or `layout_xlm` or `phi3` self.given_old_name = given_old_name self.given_new_name = given_new_name - # fmt: off - self.python_module = python_module # we store the original module to use `code_for_node` - self.transformers_imports = {} # maps the imports name like "from transformers.models.xxx" to the parsed AST module - self.imported_mapping = {} # stores the name of the imported classes, with their source {"LlamaModel":"transformers.model.llama.modeling_llama"} - self.visited_module = {} # modules visited like "transformers.models.llama.modeling_llama" - self.inserted_deps = [] # nodes inserted via super dependency - self.all_imports = [] # just stores all of the imports - self.all_safe_imports = [] # stores the import under simple statements - self.global_scope_index = 0 + + self.model_specific_imported_objects: Dict[str, str] = {} # e.g. {"LlamaModel": "transformers.models.llama.modeling_llama"} + self.model_specific_modules: Dict[str, cst.Module] = {} # e.g. {"transformers.models.llama.modeling_llama": cst.Module} + + self.all_all_to_add = {} # fmt: on - self.files = { # mapping for different component bodies - "modeling": {}, - "configuration": {}, - "tokenization": {}, - "processing": {}, - "image_processing": {}, - "feature_extractor": {}, - } - self.match_patterns = "|".join(self.files.keys()) - self.all_definitions = {} - self.class_to_file_type = {} - self.current_class = None # keep track of current top-level class during visit - self.current_top_level_function = None # keep track of current top-level function during visit - # Mapping from top-level functions to classes using them - self.function_call_class_mapping = defaultdict(lambda: set()) - # Mapping from top-level functions to other top-level functions dependencies - self.function_call_dependency_mapping = defaultdict(lambda: set()) - self.added_dependencies = set() def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """When visiting imports from `transformers.models.xxx` we need to: - 1. Get the original source code - 2. Parse it into an AST Tree - 3. Add this import to `self.transformers_imports` as visited to not parse it twice + """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, + and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. """ - import_statement = self.python_module.code_for_node(node.module) + import_module = self.python_module.code_for_node(node.module) + import_statement = "." 
* len(node.relative) + import_module + if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): + return if m.matches(node.module, m.Attribute()): for imported_ in node.names: - _import = re.search(rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", import_statement) + _import = re.search( + rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement + ) if _import: - source = _import.groups()[0] + source = _import.group(1) if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): raise ValueError( f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" ) - if import_statement not in self.transformers_imports: - if "models" not in import_statement: - import_statement = "models." + import_statement - if "transformers" not in import_statement: - import_statement = "transformers." + import_statement - source_code = get_module_source_from_name(import_statement) + if import_module not in self.model_specific_modules: + if "models" not in import_module: + import_module = "models." + import_module + if "transformers" not in import_module: + import_module = "transformers." + import_module + source_code = get_module_source_from_name(import_module) tree = cst.parse_module(source_code) - self.transformers_imports[import_statement] = tree - imported_class = self.python_module.code_for_node(imported_.name) - self.imported_mapping[imported_class] = import_statement + self.model_specific_modules[import_module] = tree + imported_object = self.python_module.code_for_node(imported_.name) + self.model_specific_imported_objects[imported_object] = import_module if m.matches(node.module, m.Name()): - if "transformers" == import_statement: + if "transformers" == import_module: raise ValueError( - f"You are importing from {import_statement} directly using global imports. Import from the correct local path" + f"You are importing from {import_module} directly using global imports. Import from the correct local path" ) - def leave_SimpleStatementLine(self, original_node, updated_node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) + def visit_SimpleStatementLine(self, node): + """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment, + simply record it or, if it is `__all__`, split it between files where we should dispatch it. 
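A condensed, standalone version of the suffix-based dispatch performed by `find_file_type` and `split_all_assignment` is sketched below; the truncated `TYPE_TO_FILE_TYPE` table and the class names are illustrative only.

```python
import re
from collections import defaultdict

TYPE_TO_FILE_TYPE = {"Config": "configuration", "Tokenizer": "tokenization", "Processor": "processing"}

def find_file_type(class_name: str) -> str:
    # Match on the class-name suffix; anything without a known suffix goes to the modeling file
    match = re.search(rf"({'|'.join(TYPE_TO_FILE_TYPE)})$", class_name)
    return TYPE_TO_FILE_TYPE[match.group(1)] if match else "modeling"

def split_all(public_names):
    per_file = defaultdict(list)
    for name in public_names:
        per_file[find_file_type(name)].append(name)
    return dict(per_file)

print(split_all(["MyModelConfig", "MyModelModel", "MyModelForCausalLM"]))
# {'configuration': ['MyModelConfig'], 'modeling': ['MyModelModel', 'MyModelForCausalLM']}
```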
+ """ + parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) + simple_top_level_assign_structure = m.SimpleStatementLine( + body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] + ) if m.matches(parent_node, m.Module()): - if m.matches(updated_node, m.SimpleStatementLine(body=[m.Import()])): - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(updated_node, m.SimpleStatementLine(body=[m.ImportFrom()])): - full_statement = self.python_module.code_for_node(updated_node.body[0].module) - if re.search( - rf"(transformers\.models\..|..)*\.({self.match_patterns})_.*", full_statement - ): # OR MATCH ..llama.modeling_llama - return cst.RemoveFromParent() - if updated_node not in self.all_imports: - self.all_imports.append(updated_node) - return updated_node - elif m.matches(original_node, m.SimpleStatementLine(body=[m.Assign()])): - if original_node.body[0].targets[0].target.value in ASSIGNMENTS_TO_KEEP.keys(): - file_ = ASSIGNMENTS_TO_KEEP[original_node.body[0].targets[0].target.value] - self.files[file_][original_node.body[0].targets[0].target.value] = { - "node": original_node, - "insert_idx": self.global_scope_index, - } - self.global_scope_index += 100 - return updated_node - - def visit_ClassDef(self, node: cst.ClassDef): - """Used to keep track of current class""" - self.current_class = node.name.value + if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): + self.imports.append(node) + elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): + import_module = self.python_module.code_for_node(node.body[0].module) + import_statement = "." * len(node.body[0].relative) + import_module + if not ( + re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) + and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) + ): + self.imports.append(node) + elif m.matches(node, simple_top_level_assign_structure): + assigned_variable = node.body[0].targets[0].target.value + # __all__ is treated differently and not added to general assignments + if assigned_variable == "__all__": + self.all_all_to_add = split_all_assignment(node) + else: + self.assignments[assigned_variable] = node - def leave_ClassDef(self, original_node, updated_node): + def leave_Module(self, node): + """When we leave the modular file, we do the following in order: + 1. compute the nested (recursive) function and assignment dependencies + 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update + its dependency graph with the new function and assignment definitions found in the modular + 3. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) """ - 1. Filter the `base` classes of this class - If they are from `transformers.models.xx` then: - - take the AST tree of the module it comes from and parse it with a `ClassFinder`. - - rename all every instance of `old_name` (llama) to `new_name` (gemma) - 2. We insert the modules which the inherited base depends on. This has to be done in - the order of the dependencies. If on is already in the new_body (because it's defined in the diff file) - then we remove it from the new body to add it again in the correct order. - 3. Replace the calls to `super().xxxx` merging parent code + # Takes care of finalizing our visit + super().leave_Module(node) + + # 1. 
compute the nested (recursive) function and assignment dependencies + self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() + + # 2. for each modeling file found in the imports, rename it with the new model name, visit it, and update dependencies + self.visited_modules = {} + self.renamers = {} + for file, module in self.model_specific_modules.items(): + file_model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", file).groups()[0] + renamer = ReplaceNameTransformer( + file_model_name, self.model_name, self.given_old_name, self.given_new_name + ) + renamed_module = module.visit(renamer) + self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( + renamed_module, + self.classes, + self.functions, + self.assignments, + self.object_dependency_mapping, + self.start_lines, + ) + # We record it so that we can rename classes later the exact same way + self.renamers[file] = renamer + + # 3. in turn, we need to add the imported functions/assignments to the dependencies of the modular mapper, using the + # definitions found in the visited files + self.merge_model_specific_imports(self.visited_modules) + + # We need to keep track of which objects were imported directly into which modeling file to not add them wrongly later + # Note that we may visit several of the same file types, thus we save them per file type, not file + self.imported_objects_per_file = defaultdict(set) + for file, mapper in self.visited_modules.items(): + file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) + self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) + + def merge_model_specific_imports(self, visited_modules): + """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, + based on the visited files.""" + self.start_lines_file_mapping = {} + self.added_objects_file_mapping = {} + for object_name, file in self.model_specific_imported_objects.items(): + visited_module = visited_modules[file] + self.start_lines_file_mapping[file] = visited_module.start_lines + # Add functions and their dependencies + if object_name in visited_module.functions and object_name not in self.functions: + self.functions[object_name] = visited_module.functions[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.object_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.functions[dep] = visited_module.global_nodes[dep] + + # Add assignments and their dependencies + elif object_name in visited_module.assignments and object_name not in self.assignments: + self.assignments[object_name] = visited_module.assignments[object_name] + self.added_objects_file_mapping[object_name] = file + dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, None) + if dependencies is not None: + self.object_recursive_dependency_mapping[object_name] = dependencies + for dep in dependencies: + if dep not in self.global_nodes: + self.added_objects_file_mapping[dep] = file + self.assignments[dep] = visited_module.global_nodes[dep] + + # Do not forget to re-assign all nodes after the merge + self.global_nodes = {**self.assignments, **self.classes, **self.functions} + + def compute_relative_order(self, 
missing_dependencies: set) -> dict[str, int]: + """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that + will be created based on the modular. """ - class_name = original_node.name.value - bases = [k.value.value for k in original_node.bases if k.value.value in self.imported_mapping] - all_bases = [k.value.value for k in original_node.bases] - self.global_scope_index += 100 - for super_class in bases: - if super_class not in self.imported_mapping: - raise ImportError( - f"{super_class} was not imported using `from transformers.models.xxxxx.modeling_xxxx import {super_class}" - ) - - super_file_name = self.imported_mapping[super_class] # we need to get the parsed tree - model_name = re.search(r"models\.\w*?\.\w*?_(\S*)", super_file_name) - if model_name: - model_name = model_name.groups()[0] + relative_order = {} + idx = 0 + + original_dependencies = [] + other_files_dependencies = defaultdict(list) + for dep in tuple(missing_dependencies): + if dep in self.added_objects_file_mapping: + file = self.added_objects_file_mapping[dep] + other_files_dependencies[file].append(dep) else: - raise ValueError( - f"Tried parsing the name of the imported package from {super_file_name}, could not extract the model name" - ) - file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - visited_module = self.visited_module - if super_file_name not in visited_module: # only extract classes once - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - else: # we are re-using the previously parsed data - class_finder = visited_module[super_file_name] - - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # so, maybe standard renaming did not work (the class name is different) - # we try with another renaming pattern - potential_given_name = get_new_part(class_name, super_class) - del visited_module[super_file_name] - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - potential_given_name, - self.model_name, - potential_given_name, - ) - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - # last recourse, if the suffix of the new class is different from the one of the super class - # e.g. 
MyNewClassForSegmentation extends MyOldClassForObjectDetection - # we try with another renaming pattern - class_finder = find_classes_in_file( - self.transformers_imports[super_file_name], - model_name, - self.model_name, - self.given_old_name, - self.given_new_name, - super_class, - class_name, - ) - visited_module[super_file_name] = class_finder - list_dependencies = { - dep: class_finder.class_start_line.get(dep, 1000) - for dep in class_finder.class_dependency_mapping.get(class_name, []) - } - if len(list_dependencies) == 0: - raise ValueError( - f"We were unable to find dependencies for {class_name} (based on inheriting from {super_class})" - f" Here are all the global dependencies that we found in you modular file: {list(class_finder.class_dependency_mapping.keys())}." - f" This usually means that the name of `{class_name}` does not match the pattern of `{super_class}`" - ) + original_dependencies.append(dep) + # Sort all lists according to the order in their respective file + all_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines[x]) + for file, dependencies in other_files_dependencies.items(): + sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) + all_dependencies += sorted_dependencies + + # Add all original node first, then merged ones (one file at a time) + for dep in all_dependencies: + relative_order[dep] = idx + idx += 1 + + return relative_order + + +def check_dependencies_and_create_import_node( + file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str +) -> tuple[set[str], dict[str, cst.CSTNode]]: + """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, + we need to remove it from the dependencies, and create a new import to it instead. + This scenario may appear in the following case: + If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py` + (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as + part of the standard dependency graph (because we never encountered an import towards this new class in any file). 
+ For example imagine the following `modular.py`: + ``` + from ..llama.modeling_llama import LlamaModel - list_dependencies = sorted(list_dependencies.items(), key=lambda x: x[1], reverse=True) - start_insert_idx = self.global_scope_index - file_to_update = self.files[file_type] - is_empty_node = self.python_module.code_for_node(original_node.body) == "pass\n" - for dependency, _ in list_dependencies: - # we can write to the correct body, using the source of the parent class - node = class_finder.global_nodes.get(dependency, None) - if node is not None: - if dependency not in file_to_update: - node = self.all_definitions.pop(dependency, node) - start_insert_idx -= 1 - file_to_update[dependency] = {"insert_idx": start_insert_idx, "node": node} - self.added_dependencies.add(dependency) - elif dependency not in self.inserted_deps: - # make sure the node is written after its dependencies - start_insert_idx = file_to_update[dependency]["insert_idx"] - 1 - if ( - dependency in file_to_update.keys() - and dependency in class_finder.first_lvl_dependency_mapping[class_name] - ): - # If dependency is defined, but not used, raise error - calls = m.findall(original_node, m.Call(func=m.Name(dependency))) - if not calls and not is_empty_node and dependency not in all_bases: - raise ValueError( - f"""You defined `{dependency}` in the modular_{self.model_name}.py, it should be used - when you define `{class_name}`, as it is one of it's direct dependencies. Make sure - you use it in the `__init__` function.""" - ) - self.inserted_deps.append(dependency) - - if len(list_dependencies) > 0: - updated_node = replace_call_to_super(class_finder, updated_node, class_name, all_bases) - - # Now, if a class was defined without parents, we look for the name - match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) - match = re.search(rf"({match_pattern})$", class_name) - if match: - key = TYPE_TO_FILE_TYPE[match.group(1)] - self.class_to_file_type[class_name] = key - self.files[key][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} - else: - self.class_to_file_type[class_name] = "modeling" - self.files["modeling"][class_name] = {"insert_idx": self.global_scope_index, "node": updated_node} + class NewNameTextConfig(PretrainedConfig): + ... - self.current_class = None - return updated_node + class NewNameConfig(PretrainedConfig): + ... - def visit_FunctionDef(self, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) - if m.matches(parent_node, m.Module()): - self.current_top_level_function = node.name.value + class NewNameModel(LlamaModel): + config = NewNameConfig() + text_config = NewNameTextConfig() + ... + ``` + then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as + `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no + knowledge of `NewNameTextConfig`. 
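A toy version of the correction described in this docstring is sketched below, on plain strings rather than cst nodes; the class names, the model name and the simplified suffix check are made up for the example, whereas the real code goes through `find_file_type` and builds `cst` import statements.

```python
import re

def file_type_for(class_name: str) -> str:
    # Simplified stand-in for `find_file_type`: only the Config suffix is handled here
    return "configuration" if re.search(r"Config$", class_name) else "modeling"

def correct_dependencies(file_type: str, dependencies, model_name: str):
    corrected, new_imports = set(dependencies), {}
    for dep in dependencies:
        dep_file_type = file_type_for(dep)
        if dep_file_type != file_type:
            # Drop the cross-file dependency and import it from the sibling generated file instead
            corrected.discard(dep)
            new_imports[dep] = f"from .{dep_file_type}_{model_name} import {dep}"
    return corrected, new_imports

deps, imports = correct_dependencies("modeling", {"NewNameTextConfig", "rotate_half"}, "newname")
print(sorted(deps))   # ['rotate_half']
print(imports)        # {'NewNameTextConfig': 'from .configuration_newname import NewNameTextConfig'}
```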
+ """ + class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())} + corrected_dependencies = new_dependencies.copy() + new_imports = {} + for class_name in class_dependencies: + class_file_type = find_file_type(class_name) + # In this case, we need to remove it from the dependencies and create a new import instead + if class_file_type != file_type: + corrected_dependencies.remove(class_name) + import_statement = f"from .{class_file_type}_{new_name} import {class_name}" + new_imports[class_name] = cst.parse_statement(import_statement) + + return corrected_dependencies, new_imports + + +def get_class_node_and_dependencies( + modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] +) -> tuple[dict, str, dict]: + """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new + class node based on the inherited classes if needed. Also returns any new imports of a new class defined in + the modular that we nay need. + """ + bases = [k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects] + if len(bases) > 1: + raise ValueError( + f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {*bases,}." + ) - def leave_FunctionDef(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - self.all_definitions[node.name.value] = node - return node - - def visit_Assign(self, node: cst.Assign) -> None: - # Check if the assignment target is '__all__' - if isinstance(node.targets[0].target, cst.Name) and node.targets[0].target.value == "__all__": - if isinstance(node.value, cst.List): - # Extract the elements from the list - all_all_to_add = defaultdict(list) - for elt in node.value.elements: - if isinstance(elt.value, cst.SimpleString): - # Remove quotes and add the string to the elements list - class_name = elt.value.value - file = self.class_to_file_type[ - elt.value.evaluated_value - ] # evaluated value give the content of the string - all_all_to_add[file] += [class_name] - for f_type, new_alls in all_all_to_add.items(): - updated_node = node.with_changes( - value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) - ) - self.files[f_type][class_name] = { - "insert_idx": self.global_scope_index + 100, - "node": updated_node, - } - - def leave_If(self, original_node, node): - parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, original_node) - if m.matches(parent_node, m.Module()): - full_statement = self.python_module.code_for_node(original_node.test) - if re.search(r"[\s\S]*is_.*available", full_statement): - self.all_safe_imports.append(node) - elif full_statement not in self.all_imports: - logger.warning(f"one import is protected with `if`. Hard guess where it's used {full_statement}") - return node - - def visit_Call(self, node: cst.Call): - """This is used to create a mapping from functions to class calling them, and from top-level functions to functions called inside them. - Important note: we only rely on direct Call to the functions here, not indirect mentions (such as assigning a variable with the function, - add calling the variable later). 
This should be enough as the `modular_xxx` and `modeling_xxx` structures should be as simple as possible.""" - # Only map function calls if we're inside a class (i.e., current_class is set) - if self.current_class is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_class_mapping[node.func.value].add(self.current_class) - elif self.current_top_level_function is not None: - # Simple function calls such as foo() - if isinstance(node.func, cst.Name): - self.function_call_dependency_mapping[self.current_top_level_function].add(node.func.value) - - def _maybe_add_function_to_body( - self, - top_level_function: str, - body: dict, - function_node: cst.FunctionDef, - matching_callers: Optional[set] = None, - parent: Optional[str] = None, - ) -> bool: - """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` - is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return - `True`. Return `False` otherwise. - """ - if matching_callers is None and parent is None: - raise ValueError("Cannot add function if both the parent and the matching callers are None.") - if matching_callers is None: - matching_callers = {parent} - if len(matching_callers) > 0 and top_level_function not in body.keys(): - # Add the function just before the first class using it - new_idx = min([body[element]["insert_idx"] for element in matching_callers]) - # Reorder the elements - for element in body.keys(): - if body[element]["insert_idx"] >= new_idx: - body[element]["insert_idx"] += 1 - # Assign new element to body (after changing the count to avoid messing it) - body[top_level_function] = {"insert_idx": new_idx, "node": function_node} - return True - return False - - def _recursively_add_all_new_needed_functions_in_files(self): - """For all top-level functions which were newly defined in the `modular_xxx.py`, check if they are used in a class in - the different files, and add them to the file if it is the case (also recursively adding all other functions that - may be needed in that function body).""" - # At this point, `self.all_definitions` only contains newly defined top-level functions in the `modualr_xxx.py` - for top_level_function, function_node in self.all_definitions.items(): - calling_entities = self.function_call_class_mapping[top_level_function] - # The function may be needed in different files, we need to iterate on them - for file, body in self.files.items(): - file_elements = set(body.keys()) - # If the intersection is not null, top_level_func must be added to file - matching_callers = calling_entities & file_elements - added = self._maybe_add_function_to_body(top_level_function, body, function_node, matching_callers) - # If the function was added, we need to recursively add all its dependencies - if added: - for dependency, parent in find_all_dependencies( - top_level_function, self.function_call_dependency_mapping - ): - self._maybe_add_function_to_body( - dependency, body, self.all_definitions[dependency], parent=parent - ) + file_type = find_file_type(class_name) + file_to_update = files[file_type] + model_name = modular_mapper.model_name - def leave_Module(self, original_node: cst.Module, node): - imports = {self.python_module.code_for_node(k): k for k in self.all_imports} - dependency_imports = {file_type: imports.copy() for file_type in self.files} - for super_file_name, visiter in self.visited_module.items(): - 
file_type = re.search(r"models?\.\w*?\.(\w*?)_", super_file_name).groups()[0] - dependency_imports[file_type].update( - {self.python_module.code_for_node(k): k for k in visiter.imports.values()} - ) + # This is used to avoid adding objects to the dependencies graph if they will be imported already + imported_objects = modular_mapper.imported_objects_per_file[file_type] + + # We need to replace the class node with the transformers (modeling file) super class node + if len(bases) == 1: + super_class = bases[0] + super_file_name = modular_mapper.model_specific_imported_objects[super_class] + + # Get the mapper corresponding to the inherited class + mapper = modular_mapper.visited_modules[super_file_name] + # Rename the super class according to the exact same rule we used when renaming the whole module + renamer = modular_mapper.renamers[super_file_name] + renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.default_name) + renamed_super_class = convert_to_camelcase(renamed_super_class, renamer.old_name, renamer.default_old_name) + + # Create the new class node + updated_node = replace_class_node(mapper, node, renamed_super_class) + + # Grab all immediate dependencies of the new node + new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) + + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + new_node_dependencies, new_imports = check_dependencies_and_create_import_node( + file_type, new_node_dependencies, mapper, model_name + ) + + # The node was modified -> look for all recursive dependencies of the new node + all_dependencies_to_add = find_all_dependencies( + dependency_mapping=mapper.class_dependency_mapping, + initial_dependencies=new_node_dependencies, + initial_checked_dependencies=set(file_to_update.keys()), + ) + + relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add + } + + # No transformers (modeling file) super class, just check functions and assignments dependencies + else: + updated_node = node + # The node was NOT modified -> no need to look recursively for other class dependencies. 
Indeed, even if they are not + # already defined (which would mean a weird order of the code in the modular...), they will be in the future + all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) + + # At this point, if any class dependency is found, but belongs to another file, it means that we need to remove + # it from the dependencies, and add a new import of it instead + all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node( + file_type, all_dependencies_to_add, modular_mapper, model_name + ) + + relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) + nodes_to_add = { + dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) + for dep in all_dependencies_to_add + if dep not in file_to_update.keys() + } + + # Add the class node itself to the nodes to add + class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 + nodes_to_add[class_name] = (class_idx, updated_node) + + return nodes_to_add, file_type, new_imports + + +def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: + """Create all the new modules based on visiting the modular file. It replaces all classes as necesary.""" + files = defaultdict(dict) + current_file_indices = defaultdict(lambda: 0) + + # For each class defined in modular, potentially replace the node and add it with its dependencies + for class_name, node in modular_mapper.classes.items(): + nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files) + + # Add the new potential new imports that we may need to the `modular_mapper` variable + modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys()) + modular_mapper.imports.extend(list(new_imports.values())) + + # Sort the nodes according to their relative order + nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) + # Write all nodes to file + for dependency, (_, node) in nodes_to_add: + # This is used to keep certain variables at the beginning of the file + try: + # The -1000 is arbitrary -> just keep it bigger than the list + idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) + except ValueError: + idx = current_file_indices[file_type] + current_file_indices[file_type] += 1 + files[file_type][dependency] = {"insert_idx": idx, "node": node} + + # Add the __all__ statement to files at the end + for file_type, node in modular_mapper.all_all_to_add.items(): + idx = current_file_indices[file_type] + files[file_type]["__all__"] = {"insert_idx": idx, "node": node} + + # Aggregate all the imports statements (we look for duplicates with the code_for_node, not the nodes themselves because + # they are wrapped in SimpleStatementLine or If which could have different newlines, blanks etc) + all_imports = modular_mapper.imports.copy() + all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} + for file, mapper in modular_mapper.visited_modules.items(): + new_imports = [ + node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code + ] + new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} + all_imports.extend(new_imports) + all_imports_code.update(new_imports_code) - # Check if any new top-level function from the `modular_xxx.py` should be added to the different files - # (if it is called in a 
class in the file, then it will be copy pasted from `modular.py` to that file). - self._recursively_add_all_new_needed_functions_in_files() + # Find the correct imports, and write the new modules + for file, body in files.items(): + new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] + needed_imports = get_needed_imports(body, all_imports) + full_module = needed_imports + new_body + new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header) + files[file] = new_module - for file, body in self.files.items(): - new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] - if len(new_body) > 0: - if file in dependency_imports.keys(): - new_body = list(dependency_imports[file].values()) + new_body - new_module = cst.Module(body=[*new_body], header=node.header) - # Final cleanup - new_module = MetadataWrapper(new_module).visit(PostModularConverterCleaner(self.added_dependencies)) - self.files[file] = new_module - return node + return files def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, cst_transformers=None): @@ -1137,10 +1442,10 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, module = cst.parse_module(code) wrapper = MetadataWrapper(module) if cst_transformers is None: - cst_transformers = ModularConverterTransformer(module, model_name, old_model_name, new_model_name) + cst_transformers = ModularFileMapper(module, model_name, old_model_name, new_model_name) wrapper.visit(cst_transformers) - for file, node in cst_transformers.files.items(): - if node != {}: + for file, module in create_modules(cst_transformers).items(): + if module != {}: # Get relative path starting from src/transformers/ relative_path = re.search( r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/") @@ -1149,7 +1454,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None, header = AUTO_GENERATED_MESSAGE.format( relative_path=relative_path, short_name=os.path.basename(relative_path) ) - ruffed_code = run_ruff(header + node.code, True) + ruffed_code = run_ruff(header + module.code, True) formatted_code = run_ruff(ruffed_code, False) output[file] = [formatted_code, ruffed_code] return output @@ -1180,7 +1485,7 @@ def save_modeling_file(modular_file, converted_file): parser = argparse.ArgumentParser() parser.add_argument( "--files_to_parse", - default=["src/transformers/models/roberta/modular_roberta.py"], + default=["src/transformers/models/gemma/modular_gemma.py"], nargs="+", help="A list of `modular_xxxx` files that should be converted to single model file", ) @@ -1197,6 +1502,7 @@ def save_modeling_file(modular_file, converted_file): args = parser.parse_args() if args.files_to_parse == ["all"]: args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) + args.files_to_parse += glob.glob("examples/**/modular_*.py", recursive=True) for file_name in find_priority_list(args.files_to_parse): print(f"Converting {file_name} to a single model single file format") From ac77a097e492ca83f85264414f800b23aad783ad Mon Sep 17 00:00:00 2001 From: pglorio Date: Wed, 20 Nov 2024 03:20:53 +0000 Subject: [PATCH 21/73] new modular conversion --- .../models/zamba2/configuration_zamba2.py | 2 - .../models/zamba2/modeling_zamba2.py | 358 +++++++++--------- 2 files changed, 188 insertions(+), 172 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py 
b/src/transformers/models/zamba2/configuration_zamba2.py index 5463ca6e8f5..81f2cccd5ee 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -19,8 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 79413fdf267..6a5f9019f66 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -24,22 +24,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import ( - AttentionMaskConverter, -) +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -48,29 +41,25 @@ logging, replace_return_docstrings, ) -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_mamba_ssm_available, -) +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config if is_mamba_ssm_available(): - from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined else: - selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None + selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None - -is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) - logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "Zamba2Config" + class Zamba2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -92,7 +81,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2DynamicCache(DynamicCache): +class Zamba2HybridDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). 
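To make the cache layout introduced in this hunk easier to picture, the sketch below allocates one layer's conv and SSM states with illustrative config values; the numbers are placeholders, not the model's actual defaults, and the attention key/value tensors (which start empty and grow only for hybrid layers) are omitted.

```python
import torch

# Illustrative values only (not the model's actual defaults)
batch_size, hidden_size, mamba_expand = 2, 2560, 2
mamba_d_conv, mamba_d_state, n_mamba_heads = 4, 64, 8

intermediate_size = mamba_expand * hidden_size

# Per-layer causal-conv state: (batch, intermediate_size, conv_kernel_size)
conv_state = torch.zeros(batch_size, intermediate_size, mamba_d_conv)
# Per-layer SSM state: (batch, n_mamba_heads, head_dim, ssm_state_size)
ssm_state = torch.zeros(batch_size, n_mamba_heads, intermediate_size // n_mamba_heads, mamba_d_state)

print(conv_state.shape)  # torch.Size([2, 5120, 4])
print(ssm_state.shape)   # torch.Size([2, 8, 640, 64])
```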
@@ -109,17 +98,33 @@ class Zamba2DynamicCache(DynamicCache): def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): + self.dtype = dtype self.layers_block_type = config.layers_block_type + self.has_previous_state = False # only used by mamba + self.intermediate_size = config.mamba_expand * config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads self.transformer_layers = [] self._modules = {} self._parameters = {} self._buffers = {} + for i in range(config.num_hidden_layers): + self.conv_states += [ + torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) + ] + cache_shape = ( + batch_size, + self.n_mamba_heads, + self.intermediate_size // self.n_mamba_heads, + self.ssm_state_size, + ) + self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) - self.has_previous_state = False - self.dtype = dtype - self.conv_kernel_size = config.mamba_d_conv - self.intermediate_size = int(config.mamba_expand * config.hidden_size) - + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.conv_states = { i: torch.zeros( batch_size, @@ -130,38 +135,18 @@ def __init__( ) for i in range(config.num_hidden_layers) } - num_heads = self.intermediate_size // config.mamba_headdim self.ssm_states = { i: torch.zeros( - batch_size, num_heads, config.mamba_headdim, config.mamba_d_state, device=device, dtype=dtype + batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype ) for i in range(config.num_hidden_layers) } + # for i in range(config.num_hidden_layers): + # if self.layers_block_type[i] == "hybrid": + # self.transformer_layers.append(i) + # self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + # self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - for i in range(config.num_hidden_layers): - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def reset(self): - self.conv_states.zero_() - self.ssm_states.zero_() - - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update def update( self, key_states: torch.Tensor, @@ -179,7 +164,6 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache def 
reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): @@ -193,7 +177,6 @@ def reorder_cache(self, beam_idx: torch.LongTensor): device = self.ssm_states[layer_idx].device self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # take any layer that contains cache and not empty tensor @@ -203,29 +186,27 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return self.key_cache[layer_idx].shape[-2] def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") @classmethod def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - raise NotImplementedError("Zamba2DynamicCache does not have a legacy cache equivalent.") + raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor + ) -> torch.Tensor: + conv_state = self.conv_states[layer_idx] + cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) -class MambaRMSNormGated(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6): - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states, gate=None): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - - if gate is not None: - hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + conv_state = conv_state.roll(shifts=-1, dims=-1) + conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + self.conv_states[layer_idx].zero_() + self.conv_states[layer_idx] += conv_state + return self.conv_states[layer_idx] - return self.weight * hidden_states.to(input_dtype) + def reset(self): + self.conv_states.zero_() + self.ssm_states.zero_() class Zamba2RotaryEmbedding(nn.Module): @@ -282,42 +263,6 @@ def layer_type_list(config: Zamba2Config): return ll -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class Zamba2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -345,7 +290,11 @@ class Zamba2Attention(nn.Module): """ def __init__( - self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks=None, block_id: int = None + self, + config: Zamba2Config, + layer_idx: Optional[int] = None, + num_fwd_mem_blocks: int = None, + block_id: int = None, ): super().__init__() self.config = config @@ -369,6 +318,7 @@ def __init__( self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta self.layer_block_map = layer_type_list(config) @@ -413,7 +363,6 @@ def __init__( ) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) def forward( self, @@ -421,7 +370,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -497,6 +446,40 @@ def forward( return attn_output, attn_weights, past_key_value +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + # Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: # Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward # dropped use_sliding_windows from the arguments of self._flash_attention_forward @@ -521,7 +504,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -637,7 +620,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -738,7 +721,7 @@ def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): Assumes that we only have tensors of either size 4 or 3 """ - pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 0, 0) + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) @@ -753,7 +736,7 @@ def reshape_into_chunks(input_tensor, pad_size, chunk_size): # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] input_tensor = pad_tensor_by_size(input_tensor, pad_size) - if input_tensor.ndim == 3: + if len(input_tensor.shape) == 3: # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) else: @@ -783,6 +766,9 @@ def segment_sum(input_tensor): return tensor_segsum +is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update)) + + class Zamba2MambaMixer(nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. 
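# The chunking helpers above pad seq_len up to a multiple of the scan chunk size
# (`len(input_tensor.shape)` is equivalent to `.ndim`; the change only swaps the attribute
# access). A small self-contained check, restating `pad_tensor_by_size` locally purely for
# illustration:
import torch

def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    # pad only the seq_len dimension (dim=1), as in the hunk above
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)

chunk_size = 8
x = torch.randn(2, 7, 4)                                      # (batch, seq_len, num_heads)
pad_size = (chunk_size - x.shape[1] % chunk_size) % chunk_size
padded = pad_tensor_by_size(x, pad_size)
assert padded.shape == (2, 8, 4)                              # seq_len rounded up to a multiple of chunk_size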
@@ -856,7 +842,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): def cuda_kernels_forward( self, hidden_states: torch.Tensor, - cache_params: Optional[Zamba2DynamicCache] = None, + cache_params: Optional[Zamba2HybridDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -1005,7 +991,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -1195,7 +1181,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2DynamicCache] def forward( self, hidden_states, - cache_params: Optional[Zamba2DynamicCache] = None, + cache_params: Optional[Zamba2HybridDynamicCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -1209,13 +1195,6 @@ def forward( return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) -ZAMBA2_ATTENTION_CLASSES = { - "eager": Zamba2Attention, - "flash_attention_2": Zamba2FlashAttention2, - "sdpa": Zamba2SdpaAttention, -} - - class Zamba2MLP(nn.Module): def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ @@ -1223,23 +1202,23 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. 
""" super().__init__() - self.config = config self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.act_fn = ACT2FN[config.hidden_act] + self.config = config + # self.hidden_size = config.hidden_size + # self.intermediate_size = config.intermediate_size + # self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id - self.ffn_intermediate_size = config.intermediate_size - - self.act_fn = ACT2FN[config.hidden_act] def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) - return self.act_fn(x[0]) * x[1] self.gated_act_fn = gated_act_fn - - self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.ffn_intermediate_size, bias=config.add_bias_linear) - self.down_proj = nn.Linear(self.ffn_intermediate_size, self.hidden_size, bias=config.add_bias_linear) + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_mlp_lora: self.gate_up_proj_lora_A_list = nn.ModuleList([]) @@ -1247,7 +1226,7 @@ def gated_act_fn(x): for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.ffn_intermediate_size, bias=False) + gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.intermediate_size, bias=False) else: gate_up_proj_lora_A = nn.Identity() gate_up_proj_lora_B = nn.Identity() @@ -1285,17 +1264,26 @@ def count_mem_blocks_in_config(config: Zamba2Config): return num_gs +ZAMBA2_ATTENTION_CLASSES = { + "eager": Zamba2Attention, + "flash_attention_2": Zamba2FlashAttention2, + "sdpa": Zamba2SdpaAttention, +} + + class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() + self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) num_gs = count_mem_blocks_in_config(config) self.block_id = block_id self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) - self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) - self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + # self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -1304,7 +1292,7 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1319,7 +1307,7 @@ def forward( (see fig. 2 in https://arxiv.org/pdf/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. 
- past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1342,9 +1330,9 @@ def forward( cache_position=cache_position, **kwargs, ) - + # feed-forward (MLP) hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states, layer_idx) + hidden_states = self.feed_forward(hidden_states) outputs = (hidden_states,) @@ -1372,7 +1360,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1383,7 +1371,7 @@ def forward( hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1426,11 +1414,15 @@ def forward( class Zamba2HybridLayer(nn.Module): - def __init__(self, shared_transf: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): + def __init__( + self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer + ): super().__init__() - self.shared_transf = shared_transf self.linear = linear self.mamba_decoder = mamba + self.shared_transformer = shared_transformer + # self.linear = linear + # self.mamba_decoder = mamba def forward( self, @@ -1440,7 +1432,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2DynamicCache] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1453,7 +1445,7 @@ def forward( layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2DynamicCache`, *optional*): cached past key and value projection states + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1464,7 +1456,7 @@ def forward( Indices depicting the position of the input sequence tokens in the sequence. 
""" - layer_outputs = self.shared_transf( + layer_outputs = self.shared_transformer( hidden_states, original_hidden_states=original_hidden_states, layer_idx=layer_idx, @@ -1529,7 +1521,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = False - _supports_cache_class = True # Note: only supports Zamba2DynamicCache + _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True def _init_weights(self, module): @@ -1558,8 +1550,28 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ): + """ + Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. + Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. + """ + config = super()._check_and_enable_flash_attn_2( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) -_CONFIG_FOR_DOC = "Zamba2Config" + # if using the default path -> swap sdpa by eager + if not hard_check_only and config._attn_implementation == "flash_attention_2": + config._attn_implementation = "eager" + + return config ZAMBA2_INPUTS_DOCSTRING = r""" @@ -1597,14 +1609,14 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`Zamba2DynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A Zamba2DynamicCache object containing pre-computed hidden-states (keys and values in the + past_key_values (`Zamba2HybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + A Zamba2HybridDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and `(batch_size, d_inner, d_state)` respectively. - See the `Zamba2DynamicCache` class for more details. + See the `Zamba2HybridDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1670,14 +1682,14 @@ def __init__(self, config: Zamba2Config): if config.num_mem_blocks * len(layer_type_list(config)) > 1: prefix_name = f"layers.{layer_id}." 
tied_keys = [ - "shared_transf.self_attn.q_proj.weight", - "shared_transf.self_attn.k_proj.weight", - "shared_transf.self_attn.v_proj.weight", - "shared_transf.self_attn.o_proj.weight", - "shared_transf.feed_forward.gate_up_proj.weight", - "shared_transf.feed_forward.down_proj.weight", - "shared_transf.input_layernorm.weight", - "shared_transf.pre_ff_layernorm.weight", + "shared_transformer.self_attn.q_proj.weight", + "shared_transformer.self_attn.k_proj.weight", + "shared_transformer.self_attn.v_proj.weight", + "shared_transformer.self_attn.o_proj.weight", + "shared_transformer.feed_forward.gate_up_proj.weight", + "shared_transformer.feed_forward.down_proj.weight", + "shared_transformer.input_layernorm.weight", + "shared_transformer.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] if config.use_shared_mlp_lora: @@ -1686,10 +1698,14 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: tied_keys_lora.append( - "shared_transf.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_A_list." + + str(lora_id) + + ".weight" ) tied_keys_lora.append( - "shared_transf.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_B_list." + + str(lora_id) + + ".weight" ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] @@ -1699,22 +1715,22 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: tied_keys_lora.append( - "shared_transf.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" ) tied_keys_lora.append( - "shared_transf.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] @@ -1743,7 +1759,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1780,7 +1796,7 @@ def forward( if use_cache and past_key_values is None: logger.warning_once( - "Zamba2 requires an initialized `Zamba2DynamicCache` to return a cache. 
None was " + "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was " "provided, so no cache will be returned." ) @@ -1926,7 +1942,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Zamba2DynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -2020,7 +2036,7 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # Overwitten -- has a unique cache type, `Zamba2DynamicCache` + # Overwitten -- has a unique cache type, `Zamba2HybridDynamicCache` empty_past_kv = past_key_values is None @@ -2034,7 +2050,9 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = Zamba2DynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) + past_key_values = Zamba2HybridDynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation From e59980e39bf9841f805c008d04580e84d42de266 Mon Sep 17 00:00:00 2001 From: pglorio Date: Wed, 20 Nov 2024 03:25:49 +0000 Subject: [PATCH 22/73] fix generated modeling file --- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 6a5f9019f66..dc38ff66aea 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -43,7 +43,7 @@ ) from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config - +from ..mamba2.modeling_mamba2 import MambaRMSNormGated if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update From 73a647aa807302c9e5c0f8f67af95df80a4962bd Mon Sep 17 00:00:00 2001 From: pglorio Date: Wed, 20 Nov 2024 20:52:24 +0000 Subject: [PATCH 23/73] fixed import for Zamba2RMSNormGated --- .../models/zamba2/modeling_zamba2.py | 55 ++++---- .../models/zamba2/modular_zamba2.py | 122 +++++++++--------- 2 files changed, 90 insertions(+), 87 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index dc38ff66aea..a74ef174d49 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -43,7 +43,7 @@ ) from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config -from ..mamba2.modeling_mamba2 import MambaRMSNormGated + if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -61,24 +61,22 @@ _CONFIG_FOR_DOC = "Zamba2Config" -class Zamba2RMSNorm(nn.Module): +class Zamba2RMSNormGated(torch.nn.Module): def __init__(self, hidden_size, eps=1e-6): - """ - Zamba2RMSNorm is equivalent to T5LayerNorm - """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - def forward(self, hidden_states): + def forward(self, 
hidden_states, gate=None): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) + + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + return self.weight * hidden_states.to(input_dtype) class Zamba2HybridDynamicCache(DynamicCache): @@ -109,19 +107,6 @@ def __init__( self._modules = {} self._parameters = {} self._buffers = {} - for i in range(config.num_hidden_layers): - self.conv_states += [ - torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) - ] - cache_shape = ( - batch_size, - self.n_mamba_heads, - self.intermediate_size // self.n_mamba_heads, - self.ssm_state_size, - ) - self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] - if self.layers_block_type[i] == "hybrid": - self.transformer_layers.append(i) self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] @@ -826,7 +811,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): A = torch.arange(1, self.num_heads + 1) self.A_log = nn.Parameter(torch.log(A)) self.A_log._no_weight_decay = True - self.norm = MambaRMSNormGated(self.intermediate_size, eps=1e-5) + self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5) self.D = nn.Parameter(torch.ones(self.num_heads)) self.D._no_weight_decay = True @@ -1253,6 +1238,26 @@ def forward(self, hidden_state, layer_idx=None): return output +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Zamba2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + def count_mem_blocks_in_config(config: Zamba2Config): """ Count number of shared blocks @@ -1330,9 +1335,9 @@ def forward( cache_position=cache_position, **kwargs, ) - # feed-forward (MLP) + hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) outputs = (hidden_states,) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 1d1888e27c3..a14b3a8b020 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -22,7 +22,6 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import DynamicCache from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward @@ -283,10 +282,9 @@ def __init__( self.num_logits_to_keep = num_logits_to_keep -class Zamba2RMSNorm(ZambaRMSNorm): 
+class Zamba2RMSNormGated(MambaRMSNormGated): pass - def count_mem_blocks_in_config(config: Zamba2Config): """ Count number of shared blocks @@ -962,7 +960,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): A = torch.arange(1, self.num_heads + 1) self.A_log = nn.Parameter(torch.log(A)) self.A_log._no_weight_decay = True - self.norm = MambaRMSNormGated(self.intermediate_size, eps=1e-5) + self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5) self.D = nn.Parameter(torch.ones(self.num_heads)) self.D._no_weight_decay = True @@ -1403,64 +1401,64 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option # self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) # self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # def forward( - # self, - # hidden_states: torch.Tensor, - # original_hidden_states: torch.Tensor, - # layer_idx: int, - # attention_mask: Optional[torch.Tensor] = None, - # position_ids: Optional[torch.LongTensor] = None, - # past_key_value: Optional[Zamba2HybridDynamicCache] = None, - # output_attentions: Optional[bool] = False, - # use_cache: Optional[bool] = False, - # cache_position: Optional[torch.LongTensor] = None, - # **kwargs, - # ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - # """ - # Args: - # hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` - # original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. - # This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The - # concatenated tensor is then used as input of the pre-attention RMSNorm - # (see fig. 2 in https://arxiv.org/pdf/2405.16712). - # attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - # `(batch, sequence_length)` where padding elements are indicated by 0. - # past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states - # output_attentions (`bool`, *optional*): - # Whether or not to return the attentions tensors of all attention layers. See `attentions` under - # returned tensors for more detail. - # use_cache (`bool`, *optional*): - # If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - # (see `past_key_values`). - # cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - # Indices depicting the position of the input sequence tokens in the sequence. 
- # """ - # hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) - # hidden_states = self.input_layernorm(hidden_states) - # hidden_states, self_attn_weights, present_key_value = self.self_attn( - # hidden_states=hidden_states, - # layer_idx=layer_idx, - # attention_mask=attention_mask, - # position_ids=position_ids, - # past_key_value=past_key_value, - # output_attentions=output_attentions, - # use_cache=use_cache, - # cache_position=cache_position, - # **kwargs, - # ) - - # hidden_states = self.pre_ff_layernorm(hidden_states) - # hidden_states = self.feed_forward(hidden_states, layer_idx) - - # outputs = (hidden_states,) - - # if output_attentions: - # outputs += (self_attn_weights,) - - # if use_cache: - # outputs += (present_key_value,) - - # return outputs + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
+ """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer): From c2b72a5b8f154808cd625691f9435754752ae18d Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 21 Nov 2024 00:02:03 +0000 Subject: [PATCH 24/73] modular file cleanup --- .../models/zamba2/modeling_zamba2.py | 64 +++--- .../models/zamba2/modular_zamba2.py | 193 ++---------------- 2 files changed, 48 insertions(+), 209 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a74ef174d49..2d318d714fc 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -79,6 +79,26 @@ def forward(self, hidden_states, gate=None): return self.weight * hidden_states.to(input_dtype) +class Zamba2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Zamba2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + class Zamba2HybridDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -98,8 +118,8 @@ def __init__( ): self.dtype = dtype self.layers_block_type = config.layers_block_type - self.has_previous_state = False # only used by mamba - self.intermediate_size = config.mamba_expand * config.hidden_size + self.has_previous_state = False + self.intermediate_size = int(config.mamba_expand * config.hidden_size) self.ssm_state_size = config.mamba_d_state self.conv_kernel_size = config.mamba_d_conv self.n_mamba_heads = config.n_mamba_heads @@ -107,9 +127,6 @@ def __init__( self._modules = {} self._parameters = {} self._buffers = {} - - self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] self.conv_states = { i: torch.zeros( batch_size, @@ -126,11 +143,11 @@ def __init__( ) for i in range(config.num_hidden_layers) } - # for i in range(config.num_hidden_layers): - # if self.layers_block_type[i] == "hybrid": - # self.transformer_layers.append(i) - # self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - # self.value_cache = [torch.tensor([[]] * batch_size, 
device=device) for _ in range(config.num_hidden_layers)] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] def update( self, @@ -1191,9 +1208,6 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int self.intermediate_size = config.intermediate_size self.act_fn = ACT2FN[config.hidden_act] self.config = config - # self.hidden_size = config.hidden_size - # self.intermediate_size = config.intermediate_size - # self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id @@ -1238,26 +1252,6 @@ def forward(self, hidden_state, layer_idx=None): return output -class Zamba2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Zamba2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - def count_mem_blocks_in_config(config: Zamba2Config): """ Count number of shared blocks @@ -1287,8 +1281,6 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) - # self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) - # self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -1426,8 +1418,6 @@ def __init__( self.linear = linear self.mamba_decoder = mamba self.shared_transformer = shared_transformer - # self.linear = linear - # self.mamba_decoder = mamba def forward( self, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index a14b3a8b020..c01c0c7a543 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -26,7 +26,6 @@ from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( add_start_docstrings, is_flash_attn_greater_or_equal_2_10, @@ -285,6 +284,11 @@ def __init__( class Zamba2RMSNormGated(MambaRMSNormGated): pass + +class Zamba2RMSNorm(ZambaRMSNorm): + pass + + def count_mem_blocks_in_config(config: Zamba2Config): """ Count number of shared blocks @@ -309,60 +313,6 @@ def layer_type_list(config: Zamba2Config): return ll -# # Helper methods for segment sum computation - - -# def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): -# """ -# Padding x tensor with `pad_size` on the seq_len dim (dim=1) - -# Assumes that we only have tensors of either size 4 or 3 -# """ -# pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if input_tensor.ndim == 4 else (0, 0, 0, pad_size, 
0, 0) - -# return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) - - -# def reshape_into_chunks(input_tensor, pad_size, chunk_size): -# """ -# Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and -# simultaneously splitting it into chunk sequences. - -# Assumes that we only have tensors of either size 4 or 3 -# """ -# # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] -# input_tensor = pad_tensor_by_size(input_tensor, pad_size) - -# if input_tensor.ndim == 3: -# # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] -# return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) -# else: -# # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] -# return input_tensor.reshape( -# input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] -# ) - - -# def segment_sum(input_tensor): -# """ -# More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. -# """ -# chunk_size = input_tensor.size(-1) -# # 1. expand input tensor to have an additional dimension and repeat along that dimension -# # [..., chunk_size] -> [..., chunk_size, chunk_size] -# input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) -# # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag -# mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) -# input_tensor = input_tensor.masked_fill(~mask, 0) -# # 3. compute actual cumsum -# tensor_segsum = torch.cumsum(input_tensor, dim=-2) - -# # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) -# mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) -# tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) -# return tensor_segsum - - class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -380,20 +330,17 @@ class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): def __init__( self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None ): - super().__init__(config, batch_size, dtype, device) - # self.dtype = dtype - # self.layers_block_type = config.layers_block_type - # self.has_previous_state = False - # self.intermediate_size = int(config.mamba_expand * config.hidden_size) - # ssm_state_size = config.mamba_d_state - # conv_kernel_size = config.mamba_d_conv - # self.n_mamba_heads = config.n_mamba_heads - # self.transformer_layers = [] - # self._modules = {} - # self._parameters = {} - # self._buffers = {} - del self.conv_states - del self.ssm_states + self.dtype = dtype + self.layers_block_type = config.layers_block_type + self.has_previous_state = False + self.intermediate_size = int(config.mamba_expand * config.hidden_size) + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.n_mamba_heads = config.n_mamba_heads + self.transformer_layers = [] + self._modules = {} + self._parameters = {} + self._buffers = {} self.conv_states = { i: torch.zeros( batch_size, @@ -410,11 +357,11 @@ def __init__( ) for i in range(config.num_hidden_layers) } - # for i in 
range(config.num_hidden_layers): - # if self.layers_block_type[i] == "hybrid": - # self.transformer_layers.append(i) - # self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] - # self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + for i in range(config.num_hidden_layers): + if self.layers_block_type[i] == "hybrid": + self.transformer_layers.append(i) + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor @@ -432,61 +379,6 @@ def reset(self): self.conv_states.zero_() self.ssm_states.zero_() - # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.update - # def update( - # self, - # key_states: torch.Tensor, - # value_states: torch.Tensor, - # layer_idx: int, - # cache_kwargs: Optional[Dict[str, Any]] = None, - # ) -> Tuple[torch.Tensor, torch.Tensor]: - # # Update the cache - # if self.key_cache[layer_idx].shape[-1] == 0: - # self.key_cache[layer_idx] = key_states - # self.value_cache[layer_idx] = value_states - # else: - # self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) - # self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) - - # return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.reorder_cache - # def reorder_cache(self, beam_idx: torch.LongTensor): - # """Reorders the cache for beam search, given the selected beam indices.""" - # for layer_idx in range(len(self.key_cache)): - # device = self.key_cache[layer_idx].device - # self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - # device = self.value_cache[layer_idx].device - # self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - # device = self.conv_states[layer_idx].device - # self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) - # device = self.ssm_states[layer_idx].device - # self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) - - # # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.get_seq_length - # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - # """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # # take any layer that contains cache and not empty tensor - # layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - # if len(self.key_cache) <= layer_idx: - # return 0 - # return self.key_cache[layer_idx].shape[-2] - - - - - ### check the two methods below - # def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - # raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") - - # @classmethod - # def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - # raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.") - - -ALL_LAYERNORM_LAYERS.append(Zamba2RMSNorm) - class Zamba2RotaryEmbedding(nn.Module): def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): @@ -1337,9 +1229,6 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int """ super().__init__(config) self.config = config - # self.hidden_size = config.hidden_size - # self.intermediate_size = config.intermediate_size - # self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id @@ -1398,8 +1287,6 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) - # self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) - # self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -1473,8 +1360,6 @@ def __init__(self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.L super().__init__(shared_transformer, linear, mamba) del self.shared_transf self.shared_transformer = shared_transformer - # self.linear = linear - # self.mamba_decoder = mamba def forward( self, @@ -1565,42 +1450,6 @@ def forward( "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", ZAMBA2_START_DOCSTRING, ) -# class Zamba2PreTrainedModel(PreTrainedModel): -# config_class = Zamba2Config -# base_model_prefix = "model" -# supports_gradient_checkpointing = True -# _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] -# _skip_keys_device_placement = "past_key_values" -# _supports_flash_attn_2 = True -# _supports_sdpa = False -# _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache -# _is_stateful = True - -# def _init_weights(self, module): -# std = self.config.initializer_range -# if isinstance(module, (nn.Linear, nn.Conv1d)): -# module.weight.data.normal_(mean=0.0, std=std) -# if module.bias is not None: -# module.bias.data.zero_() -# elif isinstance(module, nn.Embedding): -# module.weight.data.normal_(mean=0.0, std=std) -# if module.padding_idx is not None: -# module.weight.data[module.padding_idx].zero_() -# elif isinstance(module, Zamba2MambaMixer): -# module.A_log._no_weight_decay = True -# module.D._no_weight_decay = True - -# dt = torch.exp( -# torch.rand(self.config.n_mamba_heads) -# * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) -# + math.log(self.config.time_step_min) -# ).clamp(min=self.config.time_step_floor) -# # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 -# inv_dt = dt + torch.log(-torch.expm1(-dt)) - -# with 
torch.no_grad(): -# module.dt_bias.copy_(inv_dt) -# module.dt_bias._no_reinit = True class Zamba2PreTrainedModel(ZambaPreTrainedModel): config_class = Zamba2Config _supports_flash_attn_2 = True From 10a0b1e1bcc8d8b541ea93b9bddeafc20e695858 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 21 Nov 2024 01:02:23 +0000 Subject: [PATCH 25/73] make fixup and model tests --- .../models/zamba/modeling_zamba.py | 9 ++-- .../models/zamba2/configuration_zamba2.py | 1 + .../models/zamba2/modular_zamba2.py | 41 ++++++++++++------- tests/models/zamba/test_modeling_zamba.py | 4 +- tests/models/zamba2/test_modeling_zamba2.py | 4 +- 5 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 115e4b19a00..b4a3a271d00 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -145,7 +145,12 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.conv_states += [ torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) ] - cache_shape = (batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size) + cache_shape = ( + batch_size, + self.n_mamba_heads, + self.intermediate_size // self.n_mamba_heads, + self.ssm_state_size, + ) self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)] if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) @@ -194,12 +199,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: return 0 return self.key_cache[layer_idx].shape[-2] - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") @classmethod - # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.") diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 81f2cccd5ee..25546afa52a 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c01c0c7a543..987f37c8a62 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,13 +15,12 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Any, Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward @@ -35,22 +34,23 @@ is_causal_conv1d_available, is_mamba_ssm_available, ) +from ..llama.modeling_llama import apply_rotary_pos_emb +from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( ZambaAttention, + ZambaAttentionDecoderLayer, ZambaForCausalLM, ZambaForSequenceClassification, + ZambaHybridDynamicCache, + ZambaHybridLayer, ZambaMambaDecoderLayer, - ZambaModel, - ZambaRMSNorm, ZambaMLP, - ZambaAttentionDecoderLayer, - ZambaHybridLayer, + ZambaModel, ZambaPreTrainedModel, - ZambaHybridDynamicCache, + ZambaRMSNorm, repeat_kv, ) -from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum -from ..llama.modeling_llama import rotate_half, apply_rotary_pos_emb + if is_mamba_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -425,7 +425,11 @@ class Zamba2Attention(ZambaAttention): """ def __init__( - self, config: Zamba2Config, layer_idx: Optional[int] = None, num_fwd_mem_blocks: int = None, block_id: int = None + self, + config: Zamba2Config, + layer_idx: Optional[int] = None, + num_fwd_mem_blocks: int = None, + block_id: int = None, ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks @@ -1235,6 +1239,7 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int def gated_act_fn(x): x = torch.chunk(x, 2, dim=-1) return self.act_fn(x[0]) * x[1] + self.gated_act_fn = gated_act_fn del self.gate_proj @@ -1356,7 +1361,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int): class Zamba2HybridLayer(ZambaHybridLayer): - def __init__(self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer): + def __init__( + self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer + ): super().__init__(shared_transformer, linear, mamba) del self.shared_transf self.shared_transformer = shared_transformer @@ -1624,10 +1631,14 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_A_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_A_list." + + str(lora_id) + + ".weight" ) tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_B_list." + str(lora_id) + ".weight" + "shared_transformer.feed_forward.gate_up_proj_lora_B_list." 
+ + str(lora_id) + + ".weight" ) lora_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] @@ -1707,7 +1718,9 @@ def prepare_inputs_for_generation( elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] else: - past_key_values = Zamba2HybridDynamicCache(self.config, input_ids.shape[0], dtype=self.dtype, device=self.device) + past_key_values = Zamba2HybridDynamicCache( + self.config, input_ids.shape[0], dtype=self.dtype, device=self.device + ) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 20743c796f0..ee47f98a1f4 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -215,9 +215,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba needs the cache to be initialized to return a cache! - past_key_values = ZambaHybridDynamicCache( - config, input_ids.shape[0], model.dtype, device=model.device - ) + past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index e3ca547923b..285ae92b2b3 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -46,7 +46,7 @@ Zamba2Model, ) from transformers.models.zamba2.modeling_zamba2 import ( - Zamba2DynamicCache, + Zamba2HybridDynamicCache, ) @@ -221,7 +221,7 @@ def create_and_check_decoder_model_past_large_inputs( # first forward pass # Attention: Zamba2 needs the cache to be initialized to return a cache! - past_key_values = Zamba2DynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) + past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device) outputs = model( input_ids, attention_mask=input_mask, From 0270667f251cf634796e80028d7e7060d26ce293 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 23 Nov 2024 02:15:21 +0000 Subject: [PATCH 26/73] dropped inheritance for Zamba2PreTrainedModel --- .../models/zamba2/modeling_zamba2.py | 23 ---------------- .../models/zamba2/modular_zamba2.py | 26 +++++-------------- 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2d318d714fc..53411ced2f9 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1545,29 +1545,6 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, - check_device_map: bool = True, - hard_check_only: bool = False, - ): - """ - Overloads `PreTrainedModel._check_and_enable_flash_attn_2` so as to DISABLE Flash Attention 2 by default on Zamba2 models. - Flash attention 2 is currently not supported in the HuggingFace implementation of Zamba2 v1. 
- """ - config = super()._check_and_enable_flash_attn_2( - config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map - ) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "flash_attention_2": - config._attn_implementation = "eager" - - return config - ZAMBA2_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 987f37c8a62..6e5fafea64e 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1457,11 +1457,16 @@ def forward( "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", ZAMBA2_START_DOCSTRING, ) -class Zamba2PreTrainedModel(ZambaPreTrainedModel): +class Zamba2PreTrainedModel(PreTrainedModel): config_class = Zamba2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache - _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] + _is_stateful = True def _init_weights(self, module): std = self.config.initializer_range @@ -1489,23 +1494,6 @@ def _init_weights(self, module): module.dt_bias.copy_(inv_dt) module.dt_bias._no_reinit = True - @classmethod - def _check_and_enable_flash_attn_2( - cls, - config, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, - check_device_map: bool = True, - hard_check_only: bool = False, - ): - return super(PreTrainedModel, cls)._check_and_enable_flash_attn_2( - config, - torch_dtype=torch_dtype, - device_map=device_map, - check_device_map=check_device_map, - hard_check_only=hard_check_only, - ) - ZAMBA2_INPUTS_DOCSTRING = r""" Args: From 189c8c54571064221b30822fde479d14535a0638 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 23 Nov 2024 02:24:58 +0000 Subject: [PATCH 27/73] make fixup and unit tests --- src/transformers/models/zamba2/modular_zamba2.py | 3 +-- tests/models/zamba2/test_modeling_zamba2.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 6e5fafea64e..b835beb0a65 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,7 +15,7 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.utils.checkpoint @@ -46,7 +46,6 @@ ZambaMambaDecoderLayer, ZambaMLP, ZambaModel, - ZambaPreTrainedModel, ZambaRMSNorm, repeat_kv, ) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 285ae92b2b3..60e302e4253 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -60,7 +60,7 @@ def __init__( use_input_mask=True, use_labels=True, vocab_size=99, - hidden_size=32, + hidden_size=16, mamba_d_state=2, chunk_size=8, mamba_dt_rank="auto", From fa5f79e873dcd3377bfb0bdb859043fef4550fa3 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 5 Dec 2024 00:46:36 +0000 Subject: [PATCH 28/73] Add inheritance of rope from GemmaRotaryEmbedding --- src/transformers/models/zamba2/__init__.py | 42 +----- .../models/zamba2/configuration_zamba2.py | 5 +- .../models/zamba2/modeling_zamba2.py | 31 +++-- .../models/zamba2/modular_zamba2.py | 125 +++--------------- 4 files changed, 45 insertions(+), 158 deletions(-) diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py index af01a5f2a64..965db3767c7 100644 --- a/src/transformers/models/zamba2/__init__.py +++ b/src/transformers/models/zamba2/__init__.py @@ -13,45 +13,15 @@ # limitations under the License. from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure -_import_structure = { - "configuration_zamba2": ["Zamba2Config"], -} - - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - pass -else: - _import_structure["modeling_zamba2"] = [ - "Zamba2ForCausalLM", - "Zamba2ForSequenceClassification", - "Zamba2Model", - "Zamba2PreTrainedModel", - ] - if TYPE_CHECKING: - from .configuration_zamba2 import Zamba2Config - - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - pass - else: - from .modeling_zamba2 import ( - Zamba2ForCausalLM, - Zamba2ForSequenceClassification, - Zamba2Model, - Zamba2PreTrainedModel, - ) - - + from .configuration_zamba2 import * + from .modeling_zamba2 import * else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 25546afa52a..c4af0a52d01 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -142,7 +142,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - n_mamba_heads=1, + n_mamba_heads=8, use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -231,3 +231,6 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.num_logits_to_keep = num_logits_to_keep + + +__all__ = ["Zamba2Config"] diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 53411ced2f9..6a6d953d754 100644 --- 
a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -212,20 +212,19 @@ def reset(self): class Zamba2RotaryEmbedding(nn.Module): - def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() + self.dim = dim self.max_position_embeddings = max_position_embeddings - if config.use_long_context: - a = 8 - base = base * a ** (dim / (dim - 2)) self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - def forward(self, x, position_ids): + def forward(self, x, position_ids, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() # Force float32 since bfloat16 loses precision on long contexts @@ -256,13 +255,11 @@ def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers """ - ll = [] - i = 0 - for val in config.layers_block_type: - if val == "hybrid": - ll.append(i) - i += 1 - return ll + output_list = [] + for index, type in enumerate(config.layers_block_type): + if type == "hybrid": + output_list.append(index) + return output_list class Zamba2Attention(nn.Module): @@ -323,6 +320,9 @@ def __init__( self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + self.rope_theta = self.rope_theta * a ** (self.head_dim / (self.head_dim - 2)) self.layer_block_map = layer_type_list(config) self.block_id = block_id @@ -2174,3 +2174,6 @@ def forward( hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, ) + + +__all__ = ["Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel"] diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index b835beb0a65..1d6df20f32e 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -36,6 +36,7 @@ ) from ..llama.modeling_llama import apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum +from ..gemma.modeling_gemma import GemmaRotaryEmbedding from ..zamba.modeling_zamba import ( ZambaAttention, ZambaAttentionDecoderLayer, @@ -189,7 +190,7 @@ def __init__( time_step_max=0.1, time_step_floor=1e-4, time_step_limit=None, - n_mamba_heads=1, + n_mamba_heads=8, use_conv_bias=True, chunk_size=256, add_bias_linear=False, @@ -303,13 +304,11 @@ def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers """ - ll = [] - i = 0 - for val in config.layers_block_type: - if val == "hybrid": - ll.append(i) - i += 1 - return ll + output_list = [] + for index, type in enumerate(config.layers_block_type): + if type == "hybrid": + output_list.append(index) + return output_list class 
Zamba2HybridDynamicCache(ZambaHybridDynamicCache): @@ -379,33 +378,8 @@ def reset(self): self.ssm_states.zero_() -class Zamba2RotaryEmbedding(nn.Module): - def __init__(self, config, dim, max_position_embeddings=4096, base=10000, device=None): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - if config.use_long_context: - a = 8 - base = base * a ** (dim / (dim - 2)) - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +class Zamba2RotaryEmbedding(GemmaRotaryEmbedding): + pass class Zamba2Attention(ZambaAttention): @@ -433,6 +407,9 @@ def __init__( super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks self.rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + self.rope_theta = self.rope_theta * a ** (self.head_dim / (self.head_dim - 2)) self.layer_block_map = layer_type_list(config) self.block_id = block_id self.is_causal = True @@ -1668,71 +1645,8 @@ def __init__(self, config: Zamba2Config): self.post_init() -# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA -class Zamba2ForCausalLM(ZambaForCausalLM, Zamba2PreTrainedModel, GenerationMixin): - def __init__(self, config: Zamba2Config): - super().__init__(config) - self.model = Zamba2Model(config) - self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys] - - # Initialize weights and apply final processing - self.post_init() - - # Adapted from transformers.models.zamba.modeling_zamba.ZambaForCausalLM.prepare_inputs_for_generation - # with `Zamba2HybridDynamicCache` -> `Zamba2HybridDynamicCache` - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - **kwargs, - ): - # Overwitten -- has a unique cache type, `Zamba2HybridDynamicCache` - - empty_past_kv = past_key_values is None - - # Omit tokens covered by past_key_values - if not empty_past_kv: - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - else: - past_key_values = 
Zamba2HybridDynamicCache( - self.config, input_ids.shape[0], dtype=self.dtype, device=self.device - ) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if not empty_past_kv: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and empty_past_kv: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, - "cache_position": cache_position, - } - ) - return model_inputs +class Zamba2ForCausalLM(ZambaForCausalLM): + pass @add_start_docstrings( @@ -1750,11 +1664,8 @@ def prepare_inputs_for_generation( """, ZAMBA2_START_DOCSTRING, ) -class Zamba2ForSequenceClassification(ZambaForSequenceClassification, Zamba2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.model = Zamba2Model(config) - self._tied_weights_keys = self.model._tied_weights_keys +class Zamba2ForSequenceClassification(ZambaForSequenceClassification): + pass - # Initialize weights and apply final processing - self.post_init() + +__all__ = ["Zamba2Config", "Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel",] \ No newline at end of file From 8079ae035614cdf662af92093c3a84887639821a Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 5 Dec 2024 01:46:34 +0000 Subject: [PATCH 29/73] moved rope to model init --- .../models/zamba2/modeling_zamba2.py | 87 +++++- .../models/zamba2/modular_zamba2.py | 267 +++++++++++++++++- 2 files changed, 330 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 6a6d953d754..df4832bac49 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -56,6 +56,7 @@ else: causal_conv1d_update, causal_conv1d_fn = None, None + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Zamba2Config" @@ -319,10 +320,6 @@ def __init__( self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) self.num_fwd_mem_blocks = num_fwd_mem_blocks - self.rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - self.rope_theta = self.rope_theta * a ** (self.head_dim / (self.head_dim - 2)) self.layer_block_map = layer_type_list(config) self.block_id = block_id @@ -357,11 +354,14 @@ def __init__( self.linear_v_lora_B_list.append(linear_v_lora_B) if config.use_mem_rope: + rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) self.rotary_emb = Zamba2RotaryEmbedding( - config, - self.head_dim, + config.attention_head_dim, max_position_embeddings=config.max_position_embeddings, - base=self.rope_theta, + base=rope_theta, ) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} @@ -376,6 +376,7 @@ def forward( 
output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -410,7 +411,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -510,6 +520,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -547,7 +558,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -626,6 +646,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -673,7 +694,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -1293,6 +1323,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1313,6 +1344,9 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. """ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) @@ -1325,6 +1359,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) @@ -1361,6 +1396,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1377,6 +1413,11 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`. """ residual = hidden_states @@ -1431,6 +1472,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1449,6 +1491,9 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
""" layer_outputs = self.shared_transformer( @@ -1461,6 +1506,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) transformer_hidden_states = layer_outputs[0] @@ -1479,6 +1525,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) if output_attentions: @@ -1629,6 +1676,7 @@ class Zamba2Model(Zamba2PreTrainedModel): def __init__(self, config: Zamba2Config): super().__init__(config) + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1713,7 +1761,16 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + if config.use_mem_rope: + rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) + self.rotary_emb = Zamba2RotaryEmbedding( + config.attention_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + ) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1779,6 +1836,12 @@ def forward( causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + # create position embeddings to be shared across the decoder layers + if self.config.use_mem_rope: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None + all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1799,6 +1862,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = layer( @@ -1812,6 +1876,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 1d6df20f32e..dca6339789b 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,7 +15,7 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,6 +25,7 @@ from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_utils import PreTrainedModel +from ...modeling_outputs import BaseModelOutputWithPast from ...utils import ( add_start_docstrings, is_flash_attn_greater_or_equal_2_10, @@ -406,10 +407,6 @@ def __init__( ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks - self.rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - self.rope_theta = self.rope_theta * a ** (self.head_dim / (self.head_dim - 2)) self.layer_block_map = layer_type_list(config) self.block_id = block_id self.is_causal = True @@ -443,13 +440,16 @@ def __init__( self.linear_k_lora_B_list.append(linear_k_lora_B) self.linear_v_lora_A_list.append(linear_v_lora_A) self.linear_v_lora_B_list.append(linear_v_lora_B) - + if config.use_mem_rope: + rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) self.rotary_emb = Zamba2RotaryEmbedding( - config, - self.head_dim, + config.attention_head_dim, max_position_embeddings=config.max_position_embeddings, - base=self.rope_theta, + base=rope_theta, ) self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} @@ -464,6 +464,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -498,7 +499,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -564,6 +574,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, ): bsz, q_len, _ = hidden_states.size() @@ -601,7 +612,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -680,6 +700,7 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -727,7 +748,16 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) if self.config.use_mem_rope: - cos, sin = self.rotary_emb(value_states, position_ids) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -1280,6 +1310,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1300,6 +1331,9 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
""" hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) @@ -1312,6 +1346,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) @@ -1335,6 +1370,72 @@ def __init__(self, config: Zamba2Config, layer_idx: int): self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: Optional[torch.Tensor] = None, + layer_idx: int = None, + attention_mask: Optional[torch.Tensor] = None, + causal_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, + transformer_hidden_states: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`. + """ + + residual = hidden_states + + # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712). + # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712). 
+ hidden_states = ( + hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.mamba( + hidden_states=hidden_states, + cache_params=past_key_value, + attention_mask=attention_mask, + ) + + self_attn_weights = None + + # residual connection after mamba + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + class Zamba2HybridLayer(ZambaHybridLayer): def __init__( @@ -1356,6 +1457,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1374,6 +1476,9 @@ def forward( (see `past_key_values`). cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. """ layer_outputs = self.shared_transformer( @@ -1386,6 +1491,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) transformer_hidden_states = layer_outputs[0] @@ -1404,6 +1510,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) if output_attentions: @@ -1554,6 +1661,7 @@ class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): def __init__(self, config: Zamba2Config): Zamba2PreTrainedModel.__init__(self, config) + self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -1638,12 +1746,145 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - + if config.use_mem_rope: + rope_theta = config.rope_theta + if config.use_long_context: + a = 8 + rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) + self.rotary_emb = Zamba2RotaryEmbedding( + config.attention_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=rope_theta, + ) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + 
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + original_hidden_states = torch.clone(inputs_embeds) + # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer + + if use_cache and past_key_values is None: + logger.warning_once( + "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." + ) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + # create position embeddings to be shared across the decoder layers + if self.config.use_mem_rope: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for layer_idx, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + original_hidden_states, + layer_idx, + attention_mask, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + original_hidden_states=original_hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + causal_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + if layer_outputs[1] is not None: + # append attentions only of attention layers. 
Mamba layers return `None` as the attention weights + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + + next_cache = None if not use_cache else past_key_values + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + class Zamba2ForCausalLM(ZambaForCausalLM): pass From d6206ebd695469883a89e26d1fb836b08d807c1c Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 5 Dec 2024 19:57:57 +0000 Subject: [PATCH 30/73] drop del self.self_attn and del self.feed_forward --- src/transformers/models/zamba2/modeling_zamba2.py | 6 +++--- src/transformers/models/zamba2/modular_zamba2.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index df4832bac49..2c859f61de9 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1303,14 +1303,14 @@ def count_mem_blocks_in_config(config: Zamba2Config): class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() - self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) - self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - num_gs = count_mem_blocks_in_config(config) self.block_id = block_id + num_gs = count_mem_blocks_in_config(config) self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) + self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index dca6339789b..33e6fbc05fb 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1289,11 +1289,9 @@ def forward(self, hidden_state, layer_idx=None): class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - num_gs = count_mem_blocks_in_config(config) self.block_id = block_id - del self.self_attn - del self.feed_forward + num_gs = count_mem_blocks_in_config(config) + super().__init__(config, layer_idx) self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) From cf613b71e11d69e51e118e1045405f20cd9f545d Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 5 Dec 2024 20:28:09 +0000 Subject: [PATCH 31/73] fix tests --- src/transformers/models/zamba2/__init__.py | 2 +- .../models/zamba2/configuration_zamba2.py | 2 +- .../models/zamba2/modular_zamba2.py | 22 
++++++++++++------- tests/models/zamba2/test_modeling_zamba2.py | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py index 965db3767c7..00db458c72e 100644 --- a/src/transformers/models/zamba2/__init__.py +++ b/src/transformers/models/zamba2/__init__.py @@ -24,4 +24,4 @@ import sys _file = globals()["__file__"] - sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) \ No newline at end of file + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index c4af0a52d01..3832fbf64b8 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -58,7 +58,7 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - n_mamba_heads (`int`, *optional*, defaults to 1): + n_mamba_heads (`int`, *optional*, defaults to 8): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 33e6fbc05fb..559edec066a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -22,12 +22,12 @@ from torch import nn from ...configuration_utils import PretrainedConfig -from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_utils import PreTrainedModel from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, + add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -35,9 +35,9 @@ is_causal_conv1d_available, is_mamba_ssm_available, ) +from ..gemma.modeling_gemma import GemmaRotaryEmbedding from ..llama.modeling_llama import apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum -from ..gemma.modeling_gemma import GemmaRotaryEmbedding from ..zamba.modeling_zamba import ( ZambaAttention, ZambaAttentionDecoderLayer, @@ -107,7 +107,7 @@ class Zamba2Config(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. time_step_limit (`tuple`, *optional*): Accepted range of time step values. - n_mamba_heads (`int`, *optional*, defaults to 1): + n_mamba_heads (`int`, *optional*, defaults to 8): Number of heads for the evolution matrices of mamba 2. use_conv_bias (`bool`, *optional*, defaults to `True`): Whether or not to use bias in the convolution layer of the mixer block. 
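Note on the tiny test configuration touched by these commits: the Zamba2 tester shrinks hidden_size and intermediate_size so the hybrid stack can be exercised quickly, and, as in the Zamba tests, it has to build a Zamba2HybridDynamicCache by hand because the model only returns a cache when one is passed in. The following is a minimal sketch of that pattern, not part of the patch; the config kwargs mirror the tester values shown in the test diffs, while num_hidden_layers and layers_block_type are assumptions about the constructor signature rather than values taken from this series.

    import torch

    from transformers.models.zamba2.configuration_zamba2 import Zamba2Config
    from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache, Zamba2Model

    # Tiny config mirroring the tester; num_hidden_layers and layers_block_type are assumed.
    config = Zamba2Config(
        vocab_size=99,
        hidden_size=16,
        intermediate_size=4,
        num_attention_heads=2,
        n_mamba_heads=8,
        mamba_ngroups=8,
        mamba_d_state=2,
        chunk_size=8,
        num_hidden_layers=2,                    # assumption: small depth for a smoke test
        layers_block_type=["mamba", "hybrid"],  # assumption: explicit block pattern kwarg
    )

    model = Zamba2Model(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 10))

    # Zamba2 only returns an updated cache when a Zamba2HybridDynamicCache is supplied.
    past_key_values = Zamba2HybridDynamicCache(
        config, input_ids.shape[0], dtype=model.dtype, device=model.device
    )

    with torch.no_grad():
        outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)

    print(outputs.last_hidden_state.shape)  # (1, 10, config.hidden_size)
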
@@ -440,7 +440,7 @@ def __init__( self.linear_k_lora_B_list.append(linear_k_lora_B) self.linear_v_lora_A_list.append(linear_v_lora_A) self.linear_v_lora_B_list.append(linear_v_lora_B) - + if config.use_mem_rope: rope_theta = config.rope_theta if config.use_long_context: @@ -1903,8 +1903,14 @@ class Zamba2ForCausalLM(ZambaForCausalLM): """, ZAMBA2_START_DOCSTRING, ) -class Zamba2ForSequenceClassification(ZambaForSequenceClassification): - pass +class Zamba2ForSequenceClassification(ZambaForSequenceClassification): + pass -__all__ = ["Zamba2Config", "Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel",] \ No newline at end of file +__all__ = [ + "Zamba2Config", + "Zamba2ForCausalLM", + "Zamba2ForSequenceClassification", + "Zamba2Model", + "Zamba2PreTrainedModel", +] diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 60e302e4253..9782a42c899 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -68,7 +68,7 @@ def __init__( num_attention_heads=2, n_mamba_heads=8, mamba_ngroups=8, - intermediate_size=16, + intermediate_size=4, hidden_act="gelu", hidden_mamba_act="silu", hidden_dropout_prob=0.1, From 337faed6afa70c244357c44b2076bd3920728d04 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 7 Dec 2024 03:16:43 +0000 Subject: [PATCH 32/73] renamed lora -> adapter --- .../models/zamba2/modular_zamba2.py | 267 +++++++++--------- 1 file changed, 134 insertions(+), 133 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 559edec066a..5887326e095 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -132,12 +132,12 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_mlp_lora (`bool`, *optional*, defaults to `False`): - If True, unshared LoRA's will be added to the shared MLP's. - use_shared_attention_lora (`bool`, *optional*, defaults to `False`): - If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. - lora_rank (`int`, *optional*, defaults to 128): - Rank of the LoRA in the shared MLP and shared attention layers. + use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`): + If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's. + use_shared_attention_adapter (`bool`, *optional*, defaults to `False`): + If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers. + adapter_rank (`int`, *optional*, defaults to 128): + Rank of the adapter in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. 
rope_theta (`float`, *optional*, defaults to `10000.0`): @@ -201,9 +201,9 @@ def __init__( num_key_value_heads=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_mlp_lora=False, - use_shared_attention_lora=False, - lora_rank=128, + use_shared_mlp_adapter=False, + use_shared_attention_adapter=False, + adapter_rank=128, use_mem_rope=False, rope_theta=10000, initializer_range=0.02, @@ -248,9 +248,9 @@ def __init__( self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora - self.use_shared_attention_lora = use_shared_attention_lora - self.lora_rank = lora_rank + self.use_shared_mlp_adapter = use_shared_mlp_adapter + self.use_shared_attention_adapter = use_shared_attention_adapter + self.adapter_rank = adapter_rank self.use_long_context = use_long_context self.time_step_min = time_step_min self.time_step_max = time_step_max @@ -395,7 +395,8 @@ class Zamba2Attention(ZambaAttention): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this - layer is tied, un-tied LoRA modules are added to the q, k, v projectors to increase expressivity with a small memory overhead. + layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase + expressivity with a small memory overhead (see Fig. 2 of https://arxiv.org/pdf/2411.15242). """ def __init__( @@ -411,35 +412,35 @@ def __init__( self.block_id = block_id self.is_causal = True - if config.use_shared_attention_lora: - self.linear_q_lora_A_list = nn.ModuleList([]) - self.linear_q_lora_B_list = nn.ModuleList([]) - self.linear_k_lora_A_list = nn.ModuleList([]) - self.linear_k_lora_B_list = nn.ModuleList([]) - self.linear_v_lora_A_list = nn.ModuleList([]) - self.linear_v_lora_B_list = nn.ModuleList([]) + if config.use_shared_attention_adapter: + self.linear_q_adapter_A_list = nn.ModuleList([]) + self.linear_q_adapter_B_list = nn.ModuleList([]) + self.linear_k_adapter_A_list = nn.ModuleList([]) + self.linear_k_adapter_B_list = nn.ModuleList([]) + self.linear_v_adapter_A_list = nn.ModuleList([]) + self.linear_v_adapter_B_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) - linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) - linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_q_adapter_A = nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) + linear_q_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) + linear_k_adapter_A = nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) + linear_k_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) + linear_v_adapter_A = 
nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) + linear_v_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) else: - linear_q_lora_A = nn.Identity() - linear_q_lora_B = nn.Identity() - linear_k_lora_A = nn.Identity() - linear_k_lora_B = nn.Identity() - linear_v_lora_A = nn.Identity() - linear_v_lora_B = nn.Identity() - self.linear_q_lora_A_list.append(linear_q_lora_A) - self.linear_q_lora_B_list.append(linear_q_lora_B) - self.linear_k_lora_A_list.append(linear_k_lora_A) - self.linear_k_lora_B_list.append(linear_k_lora_B) - self.linear_v_lora_A_list.append(linear_v_lora_A) - self.linear_v_lora_B_list.append(linear_v_lora_B) + linear_q_adapter_A = nn.Identity() + linear_q_adapter_B = nn.Identity() + linear_k_adapter_A = nn.Identity() + linear_k_adapter_B = nn.Identity() + linear_v_adapter_A = nn.Identity() + linear_v_adapter_B = nn.Identity() + self.linear_q_adapter_A_list.append(linear_q_adapter_A) + self.linear_q_adapter_B_list.append(linear_q_adapter_B) + self.linear_k_adapter_A_list.append(linear_k_adapter_A) + self.linear_k_adapter_B_list.append(linear_k_adapter_B) + self.linear_v_adapter_A_list.append(linear_v_adapter_A) + self.linear_v_adapter_B_list.append(linear_v_adapter_B) if config.use_mem_rope: rope_theta = config.rope_theta @@ -469,26 +470,26 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: - lora_layer_idx = self.layer_dic[layer_idx] - linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) + if self.config.use_shared_attention_adapter: + adapter_layer_idx = self.layer_dic[layer_idx] + linear_q_adapter_A = self.linear_q_adapter_A_list[adapter_layer_idx] + linear_q_adapter_B = self.linear_q_adapter_B_list[adapter_layer_idx] + q_adapter_output = linear_q_adapter_A(hidden_states) + q_adapter_output = linear_q_adapter_B(q_adapter_output) query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[lora_layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[lora_layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) + query_states = query_states + q_adapter_output + linear_k_adapter_A = self.linear_k_adapter_A_list[adapter_layer_idx] + linear_k_adapter_B = self.linear_k_adapter_B_list[adapter_layer_idx] + k_adapter_output = linear_k_adapter_A(hidden_states) + k_adapter_output = linear_k_adapter_B(k_adapter_output) key_states = self.k_proj(hidden_states) - key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[lora_layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[lora_layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = linear_v_lora_B(v_lora_output) + key_states = key_states + k_adapter_output + linear_v_adapter_A = self.linear_v_adapter_A_list[adapter_layer_idx] + linear_v_adapter_B = self.linear_v_adapter_B_list[adapter_layer_idx] + v_adapter_output = linear_v_adapter_A(hidden_states) + v_adapter_output = linear_v_adapter_B(v_adapter_output) value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output + value_states = value_states + v_adapter_output else: query_states = self.q_proj(hidden_states) key_states = 
self.k_proj(hidden_states) @@ -579,26 +580,26 @@ def forward( ): bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: + if self.config.use_shared_attention_adapter: layer_idx = self.layer_dic[layer_idx] - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) + linear_q_adapter_A = self.linear_q_adapter_A_list[layer_idx] + linear_q_adapter_B = self.linear_q_adapter_B_list[layer_idx] + q_adapter_output = linear_q_adapter_A(hidden_states) + q_adapter_output = linear_q_adapter_B(q_adapter_output) query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) + query_states = query_states + q_adapter_output + linear_k_adapter_A = self.linear_k_adapter_A_list[layer_idx] + linear_k_adapter_B = self.linear_k_adapter_B_list[layer_idx] + k_adapter_output = linear_k_adapter_A(hidden_states) + k_adapter_output = linear_k_adapter_B(k_adapter_output) key_states = self.k_proj(hidden_states) - key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = linear_v_lora_B(v_lora_output) + key_states = key_states + k_adapter_output + linear_v_adapter_A = self.linear_v_adapter_A_list[layer_idx] + linear_v_adapter_B = self.linear_v_adapter_B_list[layer_idx] + v_adapter_output = linear_v_adapter_A(hidden_states) + v_adapter_output = linear_v_adapter_B(v_adapter_output) value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output + value_states = value_states + v_adapter_output else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -719,25 +720,25 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) + if self.config.use_shared_attention_adapter: + linear_q_adapter_A = self.linear_q_adapter_A_list[layer_idx] + linear_q_adapter_B = self.linear_q_adapter_B_list[layer_idx] + q_adapter_output = linear_q_adapter_A(hidden_states) + q_adapter_output = linear_q_adapter_B(q_adapter_output) query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) + query_states = query_states + q_adapter_output + linear_k_adapter_A = self.linear_k_adapter_A_list[layer_idx] + linear_k_adapter_B = self.linear_k_adapter_B_list[layer_idx] + k_adapter_output = linear_k_adapter_A(hidden_states) + k_adapter_output = linear_k_adapter_B(k_adapter_output) key_states = self.k_proj(hidden_states) - key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = 
linear_v_lora_B(v_lora_output) + key_states = key_states + k_adapter_output + linear_v_adapter_A = self.linear_v_adapter_A_list[layer_idx] + linear_v_adapter_B = self.linear_v_adapter_B_list[layer_idx] + v_adapter_output = linear_v_adapter_A(hidden_states) + v_adapter_output = linear_v_adapter_B(v_adapter_output) value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output + value_states = value_states + v_adapter_output else: query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -1235,7 +1236,7 @@ class Zamba2MLP(ZambaMLP): def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer - is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. + is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__(config) self.config = config @@ -1254,31 +1255,31 @@ def gated_act_fn(x): self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) - if self.config.use_shared_mlp_lora: - self.gate_up_proj_lora_A_list = nn.ModuleList([]) - self.gate_up_proj_lora_B_list = nn.ModuleList([]) + if self.config.use_shared_mlp_adapter: + self.gate_up_proj_adapter_A_list = nn.ModuleList([]) + self.gate_up_proj_adapter_B_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.intermediate_size, bias=False) + gate_up_proj_adapter_A = nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False) + gate_up_proj_adapter_B = nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False) else: - gate_up_proj_lora_A = nn.Identity() - gate_up_proj_lora_B = nn.Identity() - self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) - self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) + gate_up_proj_adapter_A = nn.Identity() + gate_up_proj_adapter_B = nn.Identity() + self.gate_up_proj_adapter_A_list.append(gate_up_proj_adapter_A) + self.gate_up_proj_adapter_B_list.append(gate_up_proj_adapter_B) layer_block_map = layer_type_list(config) self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): - if self.config.use_shared_mlp_lora: + if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] - gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] - lora_output = gate_up_proj_lora_A(hidden_state) - lora_output = gate_up_proj_lora_B(lora_output) + gate_up_proj_adapter_A = self.gate_up_proj_adapter_A_list[layer_idx] + gate_up_proj_adapter_B = self.gate_up_proj_adapter_B_list[layer_idx] + adapter_output = gate_up_proj_adapter_A(hidden_state) + adapter_output = gate_up_proj_adapter_B(adapter_output) intermediate_state = self.gate_up_proj(hidden_state) - hidden_state = intermediate_state + lora_output + hidden_state = intermediate_state + adapter_output else: hidden_state 
= self.gate_up_proj(hidden_state) @@ -1695,48 +1696,48 @@ def __init__(self, config: Zamba2Config): "shared_transformer.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - if config.use_shared_mlp_lora: - tied_keys_lora = [] - lora_id = 0 + if config.use_shared_mlp_adapter: + tied_keys_adapter = [] + adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_A_list." - + str(lora_id) + if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_A_list." + + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_B_list." - + str(lora_id) + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_B_list." + + str(adapter_id) + ".weight" ) - lora_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] - if config.use_shared_attention_lora: - tied_keys_lora = [] - lora_id = 0 + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + if config.use_shared_attention_adapter: + tied_keys_adapter = [] + adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append( - "shared_transformer.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_A_list." + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_A_list." + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_A_list." + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_B_list." + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_B_list." + str(adapter_id) + ".weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_B_list." 
+ str(adapter_id) + ".weight" ) - lora_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) From f1b31a13289255e55f612041dea9df4a467831e6 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 7 Dec 2024 03:56:35 +0000 Subject: [PATCH 33/73] rewrote adapter implementation --- .../models/zamba2/configuration_zamba2.py | 24 +- .../models/zamba2/modeling_zamba2.py | 254 ++++++++---------- .../models/zamba2/modular_zamba2.py | 179 +++++------- 3 files changed, 178 insertions(+), 279 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 3832fbf64b8..7d1ead551d8 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -83,12 +83,12 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_mlp_lora (`bool`, *optional*, defaults to `False`): - If True, unshared LoRA's will be added to the shared MLP's. - use_shared_attention_lora (`bool`, *optional*, defaults to `False`): - If True, unshared LoRA's will be added to the q, k, v projectors in the shared attention layers. - lora_rank (`int`, *optional*, defaults to 128): - Rank of the LoRA in the shared MLP and shared attention layers. + use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`): + If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's. + use_shared_attention_adapter (`bool`, *optional*, defaults to `False`): + If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers. + adapter_rank (`int`, *optional*, defaults to 128): + Rank of the adapter in the shared MLP and shared attention layers. use_mem_rope (`bool`, *optional*, defaults to `False`): If True, includes RoPE in the shared attention layers. 
rope_theta (`float`, *optional*, defaults to `10000.0`): @@ -152,9 +152,9 @@ def __init__( num_key_value_heads=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_mlp_lora=False, - use_shared_attention_lora=False, - lora_rank=128, + use_shared_mlp_adapter=False, + use_shared_attention_adapter=False, + adapter_rank=128, use_mem_rope=False, rope_theta=10000, initializer_range=0.02, @@ -199,9 +199,9 @@ def __init__( self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_lora = use_shared_mlp_lora - self.use_shared_attention_lora = use_shared_attention_lora - self.lora_rank = lora_rank + self.use_shared_mlp_adapter = use_shared_mlp_adapter + self.use_shared_attention_adapter = use_shared_attention_adapter + self.adapter_rank = adapter_rank self.use_long_context = use_long_context self.time_step_min = time_step_min self.time_step_max = time_step_max diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2c859f61de9..7355d8d5ae6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -286,7 +286,8 @@ class Zamba2Attention(nn.Module): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2) Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this - layer is tied, un-tied LoRA modules are added to the q, k, v projectors to increase expressivity with a small memory overhead. + layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase + expressivity with a small memory overhead (see Fig. 2 of https://arxiv.org/pdf/2411.15242). 
""" def __init__( @@ -323,35 +324,32 @@ def __init__( self.layer_block_map = layer_type_list(config) self.block_id = block_id - if config.use_shared_attention_lora: - self.linear_q_lora_A_list = nn.ModuleList([]) - self.linear_q_lora_B_list = nn.ModuleList([]) - self.linear_k_lora_A_list = nn.ModuleList([]) - self.linear_k_lora_B_list = nn.ModuleList([]) - self.linear_v_lora_A_list = nn.ModuleList([]) - self.linear_v_lora_B_list = nn.ModuleList([]) + if config.use_shared_attention_adapter: + self.linear_q_adapter_list = nn.ModuleList([]) + self.linear_k_adapter_list = nn.ModuleList([]) + self.linear_v_adapter_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - linear_q_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_q_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) - linear_k_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_k_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) - linear_v_lora_A = nn.Linear(self.attention_hidden_size, self.config.lora_rank, bias=False) - linear_v_lora_B = nn.Linear(self.config.lora_rank, self.attention_hidden_size, bias=False) + linear_q_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) + linear_k_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) + linear_v_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) else: - linear_q_lora_A = nn.Identity() - linear_q_lora_B = nn.Identity() - linear_k_lora_A = nn.Identity() - linear_k_lora_B = nn.Identity() - linear_v_lora_A = nn.Identity() - linear_v_lora_B = nn.Identity() - self.linear_q_lora_A_list.append(linear_q_lora_A) - self.linear_q_lora_B_list.append(linear_q_lora_B) - self.linear_k_lora_A_list.append(linear_k_lora_A) - self.linear_k_lora_B_list.append(linear_k_lora_B) - self.linear_v_lora_A_list.append(linear_v_lora_A) - self.linear_v_lora_B_list.append(linear_v_lora_B) + linear_q_adapter = nn.Identity() + linear_k_adapter = nn.Identity() + linear_v_adapter = nn.Identity() + self.linear_q_adapter_list.append(linear_q_adapter) + self.linear_k_adapter_list.append(linear_k_adapter) + self.linear_v_adapter_list.append(linear_v_adapter) if config.use_mem_rope: rope_theta = config.rope_theta @@ -381,30 +379,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: - lora_layer_idx = self.layer_dic[layer_idx] - linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[lora_layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[lora_layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) - key_states = self.k_proj(hidden_states) - key_states = 
key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[lora_layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[lora_layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = linear_v_lora_B(v_lora_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + if self.config.use_shared_attention_adapter: + adapter_layer_idx = self.layer_dic[layer_idx] + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -525,30 +507,14 @@ def forward( ): bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: - layer_idx = self.layer_dic[layer_idx] - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) - key_states = self.k_proj(hidden_states) - key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = linear_v_lora_B(v_lora_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + if self.config.use_shared_attention_adapter: + adapter_layer_idx = self.layer_dic[layer_idx] + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -665,29 +631,14 @@ def forward( bsz, q_len, _ = hidden_states.size() - if self.config.use_shared_attention_lora: - linear_q_lora_A = self.linear_q_lora_A_list[layer_idx] - linear_q_lora_B = self.linear_q_lora_B_list[layer_idx] - q_lora_output = linear_q_lora_A(hidden_states) - q_lora_output = linear_q_lora_B(q_lora_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_lora_output - linear_k_lora_A = self.linear_k_lora_A_list[layer_idx] - linear_k_lora_B = self.linear_k_lora_B_list[layer_idx] - k_lora_output = linear_k_lora_A(hidden_states) - k_lora_output = linear_k_lora_B(k_lora_output) - key_states = 
self.k_proj(hidden_states) - key_states = key_states + k_lora_output - linear_v_lora_A = self.linear_v_lora_A_list[layer_idx] - linear_v_lora_B = self.linear_v_lora_B_list[layer_idx] - v_lora_output = linear_v_lora_A(hidden_states) - v_lora_output = linear_v_lora_B(v_lora_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_lora_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + if self.config.use_shared_attention_adapter: + adapter_layer_idx = self.layer_dic[layer_idx] + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -1231,7 +1182,7 @@ class Zamba2MLP(nn.Module): def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer - is tied, un-tied LoRA modules are added to the up and gate projectors to increase expressivity with a small memory overhead. + is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__() self.hidden_size = config.hidden_size @@ -1249,35 +1200,28 @@ def gated_act_fn(x): self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) - if self.config.use_shared_mlp_lora: - self.gate_up_proj_lora_A_list = nn.ModuleList([]) - self.gate_up_proj_lora_B_list = nn.ModuleList([]) + if self.config.use_shared_mlp_adapter: + self.gate_up_proj_adapter_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - gate_up_proj_lora_A = nn.Linear(self.config.hidden_size, self.config.lora_rank, bias=False) - gate_up_proj_lora_B = nn.Linear(self.config.lora_rank, 2 * self.intermediate_size, bias=False) + gate_up_proj_adapter = nn.Sequential( + nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), + ) else: - gate_up_proj_lora_A = nn.Identity() - gate_up_proj_lora_B = nn.Identity() - self.gate_up_proj_lora_A_list.append(gate_up_proj_lora_A) - self.gate_up_proj_lora_B_list.append(gate_up_proj_lora_B) + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) layer_block_map = layer_type_list(config) self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): - if self.config.use_shared_mlp_lora: + gate_up_state = self.gate_up_proj(hidden_state) + if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_proj_lora_A = self.gate_up_proj_lora_A_list[layer_idx] - gate_up_proj_lora_B = self.gate_up_proj_lora_B_list[layer_idx] - lora_output = 
gate_up_proj_lora_A(hidden_state) - lora_output = gate_up_proj_lora_B(lora_output) - intermediate_state = self.gate_up_proj(hidden_state) - hidden_state = intermediate_state + lora_output - else: - hidden_state = self.gate_up_proj(hidden_state) + gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) - hidden_state = self.gated_act_fn(hidden_state) + hidden_state = self.gated_act_fn(gate_up_state) output = self.down_proj(hidden_state) return output @@ -1712,48 +1656,60 @@ def __init__(self, config: Zamba2Config): "shared_transformer.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - if config.use_shared_mlp_lora: - tied_keys_lora = [] - lora_id = 0 + if config.use_shared_mlp_adapter: + tied_keys_adapter = [] + adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_A_list." - + str(lora_id) - + ".weight" + if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".0.weight" ) - tied_keys_lora.append( - "shared_transformer.feed_forward.gate_up_proj_lora_B_list." - + str(lora_id) - + ".weight" + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".1.weight" ) - lora_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] - if config.use_shared_attention_lora: - tied_keys_lora = [] - lora_id = 0 + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + if config.use_shared_attention_adapter: + tied_keys_adapter = [] + adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id: - tied_keys_lora.append( - "shared_transformer.self_attn.linear_q_lora_A_list." + str(lora_id) + ".weight" + if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".0.weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_k_lora_A_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".0.weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_v_lora_A_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".0.weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_q_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".1.weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_k_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".1.weight" ) - tied_keys_lora.append( - "shared_transformer.self_attn.linear_v_lora_B_list." + str(lora_id) + ".weight" + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." 
+ + str(adapter_id) + + ".1.weight" ) - lora_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora] + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 5887326e095..d420f35bbab 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -413,34 +413,31 @@ def __init__( self.is_causal = True if config.use_shared_attention_adapter: - self.linear_q_adapter_A_list = nn.ModuleList([]) - self.linear_q_adapter_B_list = nn.ModuleList([]) - self.linear_k_adapter_A_list = nn.ModuleList([]) - self.linear_k_adapter_B_list = nn.ModuleList([]) - self.linear_v_adapter_A_list = nn.ModuleList([]) - self.linear_v_adapter_B_list = nn.ModuleList([]) + self.linear_q_adapter_list = nn.ModuleList([]) + self.linear_k_adapter_list = nn.ModuleList([]) + self.linear_v_adapter_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - linear_q_adapter_A = nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) - linear_q_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) - linear_k_adapter_A = nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) - linear_k_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) - linear_v_adapter_A = nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False) - linear_v_adapter_B = nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False) + linear_q_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) + linear_k_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) + linear_v_adapter = nn.Sequential( + nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), + ) else: - linear_q_adapter_A = nn.Identity() - linear_q_adapter_B = nn.Identity() - linear_k_adapter_A = nn.Identity() - linear_k_adapter_B = nn.Identity() - linear_v_adapter_A = nn.Identity() - linear_v_adapter_B = nn.Identity() - self.linear_q_adapter_A_list.append(linear_q_adapter_A) - self.linear_q_adapter_B_list.append(linear_q_adapter_B) - self.linear_k_adapter_A_list.append(linear_k_adapter_A) - self.linear_k_adapter_B_list.append(linear_k_adapter_B) - self.linear_v_adapter_A_list.append(linear_v_adapter_A) - self.linear_v_adapter_B_list.append(linear_v_adapter_B) + linear_q_adapter = nn.Identity() + linear_k_adapter = nn.Identity() + linear_v_adapter = nn.Identity() + self.linear_q_adapter_list.append(linear_q_adapter) + self.linear_k_adapter_list.append(linear_k_adapter) + self.linear_v_adapter_list.append(linear_v_adapter) if config.use_mem_rope: rope_theta = config.rope_theta @@ -470,30 +467,14 @@ def forward( ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = 
self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - linear_q_adapter_A = self.linear_q_adapter_A_list[adapter_layer_idx] - linear_q_adapter_B = self.linear_q_adapter_B_list[adapter_layer_idx] - q_adapter_output = linear_q_adapter_A(hidden_states) - q_adapter_output = linear_q_adapter_B(q_adapter_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_adapter_output - linear_k_adapter_A = self.linear_k_adapter_A_list[adapter_layer_idx] - linear_k_adapter_B = self.linear_k_adapter_B_list[adapter_layer_idx] - k_adapter_output = linear_k_adapter_A(hidden_states) - k_adapter_output = linear_k_adapter_B(k_adapter_output) - key_states = self.k_proj(hidden_states) - key_states = key_states + k_adapter_output - linear_v_adapter_A = self.linear_v_adapter_A_list[adapter_layer_idx] - linear_v_adapter_B = self.linear_v_adapter_B_list[adapter_layer_idx] - v_adapter_output = linear_v_adapter_A(hidden_states) - v_adapter_output = linear_v_adapter_B(v_adapter_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_adapter_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -580,30 +561,14 @@ def forward( ): bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: - layer_idx = self.layer_dic[layer_idx] - linear_q_adapter_A = self.linear_q_adapter_A_list[layer_idx] - linear_q_adapter_B = self.linear_q_adapter_B_list[layer_idx] - q_adapter_output = linear_q_adapter_A(hidden_states) - q_adapter_output = linear_q_adapter_B(q_adapter_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_adapter_output - linear_k_adapter_A = self.linear_k_adapter_A_list[layer_idx] - linear_k_adapter_B = self.linear_k_adapter_B_list[layer_idx] - k_adapter_output = linear_k_adapter_A(hidden_states) - k_adapter_output = linear_k_adapter_B(k_adapter_output) - key_states = self.k_proj(hidden_states) - key_states = key_states + k_adapter_output - linear_v_adapter_A = self.linear_v_adapter_A_list[layer_idx] - linear_v_adapter_B = self.linear_v_adapter_B_list[layer_idx] - v_adapter_output = linear_v_adapter_A(hidden_states) - v_adapter_output = linear_v_adapter_B(v_adapter_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_adapter_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + adapter_layer_idx = self.layer_dic[layer_idx] + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x 
head_dim x hidden_dim @@ -720,29 +685,14 @@ def forward( bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: - linear_q_adapter_A = self.linear_q_adapter_A_list[layer_idx] - linear_q_adapter_B = self.linear_q_adapter_B_list[layer_idx] - q_adapter_output = linear_q_adapter_A(hidden_states) - q_adapter_output = linear_q_adapter_B(q_adapter_output) - query_states = self.q_proj(hidden_states) - query_states = query_states + q_adapter_output - linear_k_adapter_A = self.linear_k_adapter_A_list[layer_idx] - linear_k_adapter_B = self.linear_k_adapter_B_list[layer_idx] - k_adapter_output = linear_k_adapter_A(hidden_states) - k_adapter_output = linear_k_adapter_B(k_adapter_output) - key_states = self.k_proj(hidden_states) - key_states = key_states + k_adapter_output - linear_v_adapter_A = self.linear_v_adapter_A_list[layer_idx] - linear_v_adapter_B = self.linear_v_adapter_B_list[layer_idx] - v_adapter_output = linear_v_adapter_A(hidden_states) - v_adapter_output = linear_v_adapter_B(v_adapter_output) - value_states = self.v_proj(hidden_states) - value_states = value_states + v_adapter_output - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + adapter_layer_idx = self.layer_dic[layer_idx] + query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -1256,34 +1206,27 @@ def gated_act_fn(x): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_mlp_adapter: - self.gate_up_proj_adapter_A_list = nn.ModuleList([]) - self.gate_up_proj_adapter_B_list = nn.ModuleList([]) + self.gate_up_proj_adapter_list = nn.ModuleList([]) for i in range(self.num_fwd_mem_blocks): if i % config.num_mem_blocks == block_id: - gate_up_proj_adapter_A = nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False) - gate_up_proj_adapter_B = nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False) + gate_up_proj_adapter = nn.Sequential( + nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), + ) else: - gate_up_proj_adapter_A = nn.Identity() - gate_up_proj_adapter_B = nn.Identity() - self.gate_up_proj_adapter_A_list.append(gate_up_proj_adapter_A) - self.gate_up_proj_adapter_B_list.append(gate_up_proj_adapter_B) + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) layer_block_map = layer_type_list(config) self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): + gate_up_state = self.gate_up_proj(hidden_state) if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_proj_adapter_A = self.gate_up_proj_adapter_A_list[layer_idx] - gate_up_proj_adapter_B = self.gate_up_proj_adapter_B_list[layer_idx] - adapter_output = gate_up_proj_adapter_A(hidden_state) - adapter_output = 
gate_up_proj_adapter_B(adapter_output) - intermediate_state = self.gate_up_proj(hidden_state) - hidden_state = intermediate_state + adapter_output - else: - hidden_state = self.gate_up_proj(hidden_state) + gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) - hidden_state = self.gated_act_fn(hidden_state) + hidden_state = self.gated_act_fn(gate_up_state) output = self.down_proj(hidden_state) return output @@ -1702,14 +1645,14 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_A_list." + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + str(adapter_id) - + ".weight" + + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_B_list." + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + str(adapter_id) - + ".weight" + + ".1.weight" ) adapter_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] @@ -1719,22 +1662,22 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_A_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_q_adapter_list." + str(adapter_id) + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_A_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_A_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_v_adapter_list." + str(adapter_id) + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_B_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_q_adapter_list." + str(adapter_id) + ".1.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_B_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) + ".1.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_B_list." + str(adapter_id) + ".weight" + "shared_transformer.self_attn.linear_v_adapter_list." 
+ str(adapter_id) + ".1.weight" ) adapter_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] From 11fdd47aa3f98479c154a43c5aada5b3c870c917 Mon Sep 17 00:00:00 2001 From: pglorio Date: Sat, 7 Dec 2024 04:10:45 +0000 Subject: [PATCH 34/73] fixed tests --- .../models/zamba2/modular_zamba2.py | 34 +++++++++++++------ 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index d420f35bbab..45a1d1fb6a7 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -422,15 +422,15 @@ def __init__( linear_q_adapter = nn.Sequential( nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), - ) + ) linear_k_adapter = nn.Sequential( nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), - ) + ) linear_v_adapter = nn.Sequential( nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False), nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False), - ) + ) else: linear_q_adapter = nn.Identity() linear_k_adapter = nn.Identity() @@ -1212,7 +1212,7 @@ def gated_act_fn(x): gate_up_proj_adapter = nn.Sequential( nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), - ) + ) else: gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) @@ -1224,7 +1224,7 @@ def forward(self, hidden_state, layer_idx=None): gate_up_state = self.gate_up_proj(hidden_state) if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) + gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) hidden_state = self.gated_act_fn(gate_up_state) output = self.down_proj(hidden_state) @@ -1662,22 +1662,34 @@ def __init__(self, config: Zamba2Config): for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." + str(adapter_id) + ".0.weight" + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) + ".0.weight" + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." + str(adapter_id) + ".0.weight" + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".0.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." + str(adapter_id) + ".1.weight" + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".1.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) + ".1.weight" + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".1.weight" ) tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." + str(adapter_id) + ".1.weight" + "shared_transformer.self_attn.linear_v_adapter_list." 
+ + str(adapter_id) + + ".1.weight" ) adapter_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] From 5d0a5d46c79f83b97b5bf24eb42a9c47750decce Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 00:36:42 +0000 Subject: [PATCH 35/73] Fix torch_forward in mamba2 layer --- .../models/zamba2/modeling_zamba2.py | 67 ++++++++++++++----- .../models/zamba2/modular_zamba2.py | 5 +- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 7355d8d5ae6..e912816fddd 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -213,23 +214,55 @@ def reset(self): class Zamba2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + config: Zamba2Config, + device=None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block 
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -237,6 +270,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -989,6 +1027,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = ssm_state.to(hidden_states.device) if cache_params.has_previous_state: + gate = gate.unsqueeze(1) conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) # handle batched generation - states are copied through @@ -1170,10 +1209,6 @@ def forward( ): if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) - dtype = hidden_states.dtype - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) @@ -1185,10 +1220,10 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead. 
""" super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.act_fn = ACT2FN[config.hidden_act] - self.config = config self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 45a1d1fb6a7..84ea3932ef0 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -993,6 +993,7 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = ssm_state.to(hidden_states.device) if cache_params.has_previous_state: + gate = gate.unsqueeze(1) conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) # handle batched generation - states are copied through @@ -1174,10 +1175,6 @@ def forward( ): if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) - dtype = hidden_states.dtype - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) From ef055c90de487771d5ee70f9fac148c752c99c96 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 02:17:54 +0000 Subject: [PATCH 36/73] Fix torch_forward in mamba2 layer --- src/transformers/models/zamba2/modeling_zamba2.py | 9 ++++++++- src/transformers/models/zamba2/modular_zamba2.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index e912816fddd..997f0d9e015 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1016,7 +1016,10 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) + if cache_params is not None and cache_params.has_previous_state: + projected_states = self.in_proj(input_states.squeeze(1)) + else: + projected_states = self.in_proj(input_states) d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 @@ -1038,6 +1041,10 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding else: + if attention_mask is not None and not torch.all(attention_mask==1): + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) hidden_states = hidden_states.transpose(1,2) conv_state = nn.functional.pad( hidden_states, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 84ea3932ef0..5b74b5dad93 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -982,7 +982,10 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection - projected_states = self.in_proj(input_states.squeeze(1)) + if cache_params is not None and cache_params.has_previous_state: + projected_states = self.in_proj(input_states.squeeze(1)) + else: + projected_states = self.in_proj(input_states) d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 _, _, gate, hidden_states, dt = projected_states.split( [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 @@ -1004,6 +1007,10 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding else: + if attention_mask is not None and not torch.all(attention_mask==1): + dtype = hidden_states.dtype + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) hidden_states = hidden_states.transpose(1,2) conv_state = nn.functional.pad( hidden_states, From b993a7895e7e42be6a3f9647759a19320f6bd73e Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 04:23:09 +0000 Subject: [PATCH 37/73] Fix torch_forward in mamba2 layer --- src/transformers/models/zamba2/modeling_zamba2.py | 7 +++---- src/transformers/models/zamba2/modular_zamba2.py | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 997f0d9e015..b8ea26296a6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1019,6 +1019,9 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic if cache_params is not None and cache_params.has_previous_state: projected_states = self.in_proj(input_states.squeeze(1)) else: + if attention_mask is not None and not torch.all(attention_mask==1): + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + input_states = (input_states * attention_mask[:, :, None]).to(dtype) projected_states = self.in_proj(input_states) d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 _, _, gate, hidden_states, dt = projected_states.split( @@ -1041,10 +1044,6 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] 
# [batch, 1, intermediate_size] : decoding else: - if attention_mask is not None and not torch.all(attention_mask==1): - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) hidden_states = hidden_states.transpose(1,2) conv_state = nn.functional.pad( hidden_states, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 5b74b5dad93..651fb6b33b1 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -985,6 +985,9 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic if cache_params is not None and cache_params.has_previous_state: projected_states = self.in_proj(input_states.squeeze(1)) else: + if attention_mask is not None and not torch.all(attention_mask==1): + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + input_states = (input_states * attention_mask[:, :, None]).to(dtype) projected_states = self.in_proj(input_states) d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2 _, _, gate, hidden_states, dt = projected_states.split( @@ -1007,10 +1010,6 @@ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamic hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding else: - if attention_mask is not None and not torch.all(attention_mask==1): - dtype = hidden_states.dtype - # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 - hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) hidden_states = hidden_states.transpose(1,2) conv_state = nn.functional.pad( hidden_states, From bf93251a4ea09b664e79a9b567ec52445378a54c Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 19:29:53 +0000 Subject: [PATCH 38/73] Dropped adapter in-place sum --- .../models/zamba2/modeling_zamba2.py | 20 +++++++++---------- .../models/zamba2/modular_zamba2.py | 20 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index b8ea26296a6..71a5562427e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -422,9 +422,9 @@ def forward( value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -550,9 +550,9 @@ def forward( value_states = self.v_proj(hidden_states) if 
self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -674,9 +674,9 @@ def forward( value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -1260,7 +1260,7 @@ def forward(self, hidden_state, layer_idx=None): gate_up_state = self.gate_up_proj(hidden_state) if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) + gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) hidden_state = self.gated_act_fn(gate_up_state) output = self.down_proj(hidden_state) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 651fb6b33b1..52774a54409 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -472,9 +472,9 @@ def forward( value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -566,9 +566,9 @@ def forward( value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + 
query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) # Flash attention requires the input to have the shape # batch_size x seq_length x head_dim x hidden_dim @@ -690,9 +690,9 @@ def forward( value_states = self.v_proj(hidden_states) if self.config.use_shared_attention_adapter: adapter_layer_idx = self.layer_dic[layer_idx] - query_states += self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states += self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states += self.linear_v_adapter_list[adapter_layer_idx](hidden_states) + query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) + key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) + value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) @@ -1227,7 +1227,7 @@ def forward(self, hidden_state, layer_idx=None): gate_up_state = self.gate_up_proj(hidden_state) if self.config.use_shared_mlp_adapter: layer_idx = self.layer_dic[layer_idx] - gate_up_state += self.gate_up_proj_adapter_list[layer_idx](hidden_state) + gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) hidden_state = self.gated_act_fn(gate_up_state) output = self.down_proj(hidden_state) From 99708af8505771e07d77cee20fcc7c533e562861 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 19:42:49 +0000 Subject: [PATCH 39/73] removed rope from attention init --- src/transformers/models/zamba2/modeling_zamba2.py | 11 ----------- src/transformers/models/zamba2/modular_zamba2.py | 11 ----------- 2 files changed, 22 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 71a5562427e..96dc47f2315 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -389,17 +389,6 @@ def __init__( self.linear_k_adapter_list.append(linear_k_adapter) self.linear_v_adapter_list.append(linear_v_adapter) - if config.use_mem_rope: - rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) - self.rotary_emb = Zamba2RotaryEmbedding( - config.attention_head_dim, - max_position_embeddings=config.max_position_embeddings, - base=rope_theta, - ) - self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} def forward( diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 52774a54409..fe3620f3876 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -439,17 +439,6 @@ def __init__( self.linear_k_adapter_list.append(linear_k_adapter) self.linear_v_adapter_list.append(linear_v_adapter) - if config.use_mem_rope: - rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) - self.rotary_emb = Zamba2RotaryEmbedding( - config.attention_head_dim, - 
max_position_embeddings=config.max_position_embeddings, - base=rope_theta, - ) - self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} def forward( From d9b4a5004af2f53e281038c66ee4f984bb53dafe Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 20:47:05 +0000 Subject: [PATCH 40/73] updated rope --- .../models/zamba2/configuration_zamba2.py | 9 ++++-- .../models/zamba2/modeling_zamba2.py | 15 ++-------- .../models/zamba2/modular_zamba2.py | 28 ++++++++++--------- tests/models/zamba2/test_modeling_zamba2.py | 3 ++ 4 files changed, 27 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 7d1ead551d8..01f1a8f7776 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -184,11 +184,15 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_mem_blocks = num_mem_blocks - self.use_mem_rope = use_mem_rope - self.rope_theta = rope_theta self.attention_hidden_size = 2 * hidden_size self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads self.attention_dropout = attention_dropout + self.use_mem_rope = use_mem_rope + self.use_long_context = use_long_context + if use_mem_rope and use_long_context: + a = 8 + rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2)) + self.rope_theta = rope_theta self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand @@ -202,7 +206,6 @@ def __init__( self.use_shared_mlp_adapter = use_shared_mlp_adapter self.use_shared_attention_adapter = use_shared_attention_adapter self.adapter_rank = adapter_rank - self.use_long_context = use_long_context self.time_step_min = time_step_min self.time_step_max = time_step_max self.time_step_floor = time_step_floor diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 96dc47f2315..207e77e53b6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -220,7 +220,7 @@ def __init__( device=None, ): super().__init__() - self.rope_kwargs = {} + self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -231,8 +231,7 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -1748,15 +1747,7 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_mem_rope: - rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) - self.rotary_emb = Zamba2RotaryEmbedding( - config.attention_head_dim, - max_position_embeddings=config.max_position_embeddings, - 
base=rope_theta, - ) + self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index fe3620f3876..c80e6443cf0 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -233,11 +233,15 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_mem_blocks = num_mem_blocks - self.use_mem_rope = use_mem_rope - self.rope_theta = rope_theta self.attention_hidden_size = 2 * hidden_size self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads self.attention_dropout = attention_dropout + self.use_mem_rope = use_mem_rope + self.use_long_context = use_long_context + if use_mem_rope and use_long_context: + a = 8 + rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2)) + self.rope_theta = rope_theta self.mamba_d_state = mamba_d_state self.mamba_d_conv = mamba_d_conv self.mamba_expand = mamba_expand @@ -251,7 +255,6 @@ def __init__( self.use_shared_mlp_adapter = use_shared_mlp_adapter self.use_shared_attention_adapter = use_shared_attention_adapter self.adapter_rank = adapter_rank - self.use_long_context = use_long_context self.time_step_min = time_step_min self.time_step_max = time_step_max self.time_step_floor = time_step_floor @@ -380,7 +383,14 @@ def reset(self): class Zamba2RotaryEmbedding(GemmaRotaryEmbedding): - pass + def __init__( + self, + config: Zamba2Config, + device=None, + ): + super().__init__(config, device) + self.rope_kwargs = {'base': config.rope_theta, 'dim': config.attention_head_dim} + inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) class Zamba2Attention(ZambaAttention): @@ -1693,15 +1703,7 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_mem_rope: - rope_theta = config.rope_theta - if config.use_long_context: - a = 8 - rope_theta = rope_theta * a ** (config.attention_head_dim / (config.attention_head_dim - 2)) - self.rotary_emb = Zamba2RotaryEmbedding( - config.attention_head_dim, - max_position_embeddings=config.max_position_embeddings, - base=rope_theta, - ) + self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False # Initialize weights and apply final processing diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 9782a42c899..a64bf543237 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -82,6 +82,7 @@ def __init__( scope=None, layers_block_type=["mamba", "hybrid"], num_mem_blocks=1, + use_mem_rope=True, ): self.parent = parent self.batch_size = batch_size @@ -112,6 +113,7 @@ def __init__( self.scope = scope self.layers_block_type = layers_block_type self.num_mem_blocks = num_mem_blocks + self.use_mem_rope = use_mem_rope def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) @@ -155,6 +157,7 @@ def get_config(self): use_mamba_kernels=False, layers_block_type=self.layers_block_type, num_mem_blocks=self.num_mem_blocks, + use_mem_rope=self.use_mem_rope, ) def prepare_config_and_inputs_for_decoder(self): From 
095d853bf872e9d05036dddf4fd8f011cb8646dd Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 19 Dec 2024 21:01:14 +0000 Subject: [PATCH 41/73] created get_layers method --- .../models/zamba2/modeling_zamba2.py | 154 +++++++++--------- .../models/zamba2/modular_zamba2.py | 34 ++-- 2 files changed, 98 insertions(+), 90 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 207e77e53b6..1fe095aaa90 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1667,81 +1667,7 @@ def __init__(self, config: Zamba2Config): mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) blocks = cycle(blocks) - layers = [] - self._tied_weights_keys = [] - for layer_id, layer_type in enumerate(self.layers_block_type): - if layer_type == "hybrid": - block = next(blocks) - if config.num_mem_blocks * len(layer_type_list(config)) > 1: - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transformer.self_attn.q_proj.weight", - "shared_transformer.self_attn.k_proj.weight", - "shared_transformer.self_attn.v_proj.weight", - "shared_transformer.self_attn.o_proj.weight", - "shared_transformer.feed_forward.gate_up_proj.weight", - "shared_transformer.feed_forward.down_proj.weight", - "shared_transformer.input_layernorm.weight", - "shared_transformer.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - if config.use_shared_mlp_adapter: - tied_keys_adapter = [] - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] - if config.use_shared_attention_adapter: - tied_keys_adapter = [] - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." 
- + str(adapter_id) - + ".1.weight" - ) - adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] - layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) - else: - layers.append(next(mamba_layers)) + layers = self.get_layers(blocks, linear_layers, mamba_layers) self.layers = nn.ModuleList(layers) self._attn_implementation = config._attn_implementation @@ -1918,6 +1844,84 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): return causal_mask + def get_layers(self, blocks, linear_layers, mamba_layers): + layers = [] + self._tied_weights_keys = [] + for layer_id, layer_type in enumerate(self.layers_block_type): + if layer_type == "hybrid": + block = next(blocks) + if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transformer.self_attn.q_proj.weight", + "shared_transformer.self_attn.k_proj.weight", + "shared_transformer.self_attn.v_proj.weight", + "shared_transformer.self_attn.o_proj.weight", + "shared_transformer.feed_forward.gate_up_proj.weight", + "shared_transformer.feed_forward.down_proj.weight", + "shared_transformer.input_layernorm.weight", + "shared_transformer.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + if self.config.use_shared_mlp_adapter: + tied_keys_adapter = [] + adapter_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + if self.config.use_shared_attention_adapter: + tied_keys_adapter = [] + adapter_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." 
+ + str(adapter_id) + + ".1.weight" + ) + adapter_id += 1 + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) + else: + layers.append(next(mamba_layers)) + return layers + # Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba2, JAMBA->ZAMBA2 class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin): diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c80e6443cf0..df206663e13 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1623,12 +1623,25 @@ def __init__(self, config: Zamba2Config): mamba_layers = iter(mamba_layers) linear_layers = iter(linear_layers) blocks = cycle(blocks) + layers = self.get_layers(blocks, linear_layers, mamba_layers) + self.layers = nn.ModuleList(layers) + + self._attn_implementation = config._attn_implementation + self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.use_mem_rope: + self.rotary_emb = Zamba2RotaryEmbedding(config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": block = next(blocks) - if config.num_mem_blocks * len(layer_type_list(config)) > 1: + if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: prefix_name = f"layers.{layer_id}." tied_keys = [ "shared_transformer.self_attn.q_proj.weight", @@ -1641,11 +1654,11 @@ def __init__(self, config: Zamba2Config): "shared_transformer.pre_ff_layernorm.weight", ] self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] - if config.use_shared_mlp_adapter: + if self.config.use_shared_mlp_adapter: tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: tied_keys_adapter.append( "shared_transformer.feed_forward.gate_up_proj_adapter_list." + str(adapter_id) @@ -1658,11 +1671,11 @@ def __init__(self, config: Zamba2Config): ) adapter_id += 1 self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] - if config.use_shared_attention_adapter: + if self.config.use_shared_attention_adapter: tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % config.num_mem_blocks == block.block_id: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: tied_keys_adapter.append( "shared_transformer.self_attn.linear_q_adapter_list." 
+ str(adapter_id) @@ -1698,16 +1711,7 @@ def __init__(self, config: Zamba2Config): layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) - self.layers = nn.ModuleList(layers) - - self._attn_implementation = config._attn_implementation - self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if config.use_mem_rope: - self.rotary_emb = Zamba2RotaryEmbedding(config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() + return layers @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) def forward( From 99e343e6300680d691d0636e3c888227502fefc2 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 20 Dec 2024 01:24:24 +0000 Subject: [PATCH 42/73] make fixup fix --- src/transformers/models/zamba2/modular_zamba2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index df206663e13..fd0e4040d46 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -389,7 +389,7 @@ def __init__( device=None, ): super().__init__(config, device) - self.rope_kwargs = {'base': config.rope_theta, 'dim': config.attention_head_dim} + self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim} inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) From 4e4097579545a978d5c6f662806dabd94ede5745 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 20 Dec 2024 02:36:53 +0000 Subject: [PATCH 43/73] make fixup fixes --- src/transformers/models/zamba2/modeling_zamba2.py | 4 ++++ src/transformers/models/zamba2/modular_zamba2.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 1fe095aaa90..fea8fb6926e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1673,6 +1673,10 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_mem_rope: + if config.use_long_context: + logger.warning_once( + "`use_long_context` set to `True`, using rescaled `rope_theta` and extended `max_position_embeddings`." + ) self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index fd0e4040d46..37303c0964e 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1629,6 +1629,10 @@ def __init__(self, config: Zamba2Config): self._attn_implementation = config._attn_implementation self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.use_mem_rope: + if config.use_long_context: + logger.warning_once( + "`use_long_context` set to `True`, using rescaled `rope_theta` and extended `max_position_embeddings`." 
+ ) self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False From 61bb32fa61b2e694dd822111aa75b250ec3a03ba Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 20 Dec 2024 02:46:37 +0000 Subject: [PATCH 44/73] make fixup fixes --- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- src/transformers/models/zamba2/modular_zamba2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index fea8fb6926e..caca525f879 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1675,7 +1675,7 @@ def __init__(self, config: Zamba2Config): if config.use_mem_rope: if config.use_long_context: logger.warning_once( - "`use_long_context` set to `True`, using rescaled `rope_theta` and extended `max_position_embeddings`." + "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`." ) self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 37303c0964e..2cf8fcde8cc 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1631,7 +1631,7 @@ def __init__(self, config: Zamba2Config): if config.use_mem_rope: if config.use_long_context: logger.warning_once( - "`use_long_context` set to `True`, using rescaled `rope_theta` and extended `max_position_embeddings`." + "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`." ) self.rotary_emb = Zamba2RotaryEmbedding(config) self.gradient_checkpointing = False From cb90bb4efbab0243bc86ada0203dacfb60080c8b Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 13 Jan 2025 20:34:54 +0000 Subject: [PATCH 45/73] update to new attention standard --- .../models/zamba2/modeling_zamba2.py | 387 +++--------------- .../models/zamba2/modular_zamba2.py | 300 ++------------ 2 files changed, 92 insertions(+), 595 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index caca525f879..32f6fea233a 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -21,7 +21,7 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -31,17 +31,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config @@ -289,6 +284,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers @@ -338,25 +359,18 @@ def __init__( self.config = config self.layer_idx = layer_idx - self.hidden_size = config.hidden_size self.attention_hidden_size = config.attention_hidden_size - self.num_heads = config.num_attention_heads self.head_dim = config.attention_head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings + self.scaling = (self.head_dim / 2) ** -0.5 self.is_causal = True self.attention_dropout = config.attention_dropout - if (self.head_dim * self.num_heads) != self.attention_hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.q_proj = nn.Linear(config.attention_hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.num_fwd_mem_blocks = num_fwd_mem_blocks self.layer_block_map = layer_type_list(config) self.block_id = block_id @@ -397,13 +411,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -414,9 +428,9 @@ def forward( key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) if self.config.use_mem_rope: if position_embeddings is None: @@ -434,291 +448,30 @@ def forward( if past_key_value is not None: key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, 
self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward -# dropped use_sliding_windows from the arguments of self._flash_attention_forward -class Zamba2FlashAttention2(Zamba2Attention): - """ - Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - if self.config.use_shared_attention_adapter: - adapter_layer_idx = self.layer_dic[layer_idx] - query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if self.config.use_mem_rope: - if position_embeddings is None: + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. 
- input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=softmax_scale, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention -class Zamba2SdpaAttention(Zamba2Attention): - """ - Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - if self.config.use_shared_attention_adapter: - adapter_layer_idx = self.layer_dic[layer_idx] - query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if self.config.use_mem_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal=self.is_causal and attention_mask is None and q_len > 1, - scale=softmax_scale, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights # Helper methods for segment sum computation @@ -1266,21 +1019,12 @@ def count_mem_blocks_in_config(config: Zamba2Config): return num_gs -ZAMBA2_ATTENTION_CLASSES = { - "eager": Zamba2Attention, - "flash_attention_2": Zamba2FlashAttention2, - "sdpa": Zamba2SdpaAttention, -} - - class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() self.block_id = block_id num_gs = count_mem_blocks_in_config(config) - self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( - config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id - ) + self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1297,7 +1041,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1323,7 +1067,7 @@ def forward( """ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, @@ -1344,9 +1088,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1801,17 +1542,13 @@ def forward( if past_key_values and not past_key_values.has_previous_state: past_key_values.has_previous_state = True - next_cache = None if not use_cache else past_key_values - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if self.config._attn_implementation == "flash_attention_2": diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2cf8fcde8cc..ec2184f2c32 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -15,16 +15,17 @@ # limitations under the License. 
import math from itertools import cycle -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ...configuration_utils import PretrainedConfig -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -50,6 +51,7 @@ ZambaModel, ZambaRMSNorm, repeat_kv, + eager_attention_forward, ) @@ -420,7 +422,6 @@ def __init__( self.num_fwd_mem_blocks = num_fwd_mem_blocks self.layer_block_map = layer_type_list(config) self.block_id = block_id - self.is_causal = True if config.use_shared_attention_adapter: self.linear_q_adapter_list = nn.ModuleList([]) @@ -458,13 +459,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -475,9 +476,9 @@ def forward( key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) if self.config.use_mem_rope: if position_embeddings is None: @@ -495,264 +496,30 @@ def forward( if past_key_value is not None: key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise 
ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward -# dropped use_sliding_windows from the arguments of self._flash_attention_forward -class Zamba2FlashAttention2(Zamba2Attention): - """ - Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - if self.config.use_shared_attention_adapter: - adapter_layer_idx = self.layer_dic[layer_idx] - query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if self.config.use_mem_rope: - if position_embeddings is None: + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( - "The attention layers in this model are transitioning from 
computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) - cos, sin = self.rotary_emb(value_states, position_ids) else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=softmax_scale, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention: -# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention -class Zamba2SdpaAttention(Zamba2Attention): - """ - Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. 
- """ - - def forward( - self, - hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - if self.config.use_shared_attention_adapter: - adapter_layer_idx = self.layer_dic[layer_idx] - query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states) - key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states) - value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - if self.config.use_mem_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. 
- if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - softmax_scale = 1 / math.sqrt(self.head_dim / 2) - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - scale=softmax_scale, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -ZAMBA2_ATTENTION_CLASSES = { - "eager": Zamba2Attention, - "flash_attention_2": Zamba2FlashAttention2, - "sdpa": Zamba2SdpaAttention, -} + return attn_output, attn_weights class Zamba2MambaMixer(nn.Module): @@ -1238,7 +1005,7 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option self.block_id = block_id num_gs = count_mem_blocks_in_config(config) super().__init__(config, layer_idx) - self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Zamba2Attention( config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id ) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) @@ -1255,7 +1022,7 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1281,7 +1048,7 @@ def forward( """ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, @@ -1302,9 +1069,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs @@ -1829,17 +1593,13 @@ def forward( if past_key_values and not past_key_values.has_previous_state: past_key_values.has_previous_state = True - next_cache = None if not use_cache else past_key_values - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() class Zamba2ForCausalLM(ZambaForCausalLM): From 1dbc8c73e54e71345cfe676eea704aa862b7fc84 Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 13 Jan 2025 22:01:23 +0000 Subject: [PATCH 46/73] update to new attention standard --- .../models/zamba2/modeling_zamba2.py | 41 +++++++++++++++++-- 
.../models/zamba2/modular_zamba2.py | 38 ++++++++++++++++- src/transformers/testing_utils.py | 1 + 3 files changed, 75 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 32f6fea233a..5226bb3fcbf 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -238,13 +238,14 @@ def _dynamic_frequency_update(self, position_ids, device): """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @@ -321,6 +322,40 @@ def layer_type_list(config: Zamba2Config): return output_list +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class Zamba2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index ec2184f2c32..5210fc81857 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -37,7 +37,7 @@ is_mamba_ssm_available, ) from ..gemma.modeling_gemma import GemmaRotaryEmbedding -from ..llama.modeling_llama import apply_rotary_pos_emb +# from ..llama.modeling_llama import apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( ZambaAttention, @@ -317,6 +317,40 @@ def layer_type_list(config: Zamba2Config): return output_list +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -390,8 +424,8 @@ def __init__( config: Zamba2Config, device=None, ): - super().__init__(config, device) self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim} + super().__init__(config, device) inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 89587d303eb..039e4d254f4 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -1413,6 +1413,7 @@ def set_model_tester_for_less_flaky_test(test_case): # TODO (if possible): Avoid exceptional cases exceptional_classes = [ "ZambaModelTester", + "Zamba2ModelTester", "RwkvModelTester", "AriaVisionText2TextModelTester", "GPTNeoModelTester", From f24e45252ee284d87caefabfffbda770e8efa63c Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 13 Jan 2025 22:10:35 +0000 Subject: [PATCH 47/73] make fixup fixes --- src/transformers/models/zamba2/modular_zamba2.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 5210fc81857..c86042ffe42 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -29,7 +29,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, ) from ...utils.import_utils import ( @@ -37,6 +36,7 @@ is_mamba_ssm_available, ) from ..gemma.modeling_gemma import GemmaRotaryEmbedding + # from ..llama.modeling_llama import apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( @@ -50,7 +50,6 @@ ZambaMLP, ZambaModel, ZambaRMSNorm, - repeat_kv, eager_attention_forward, ) @@ -1039,9 +1038,7 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option self.block_id = block_id num_gs = count_mem_blocks_in_config(config) super().__init__(config, layer_idx) - self.self_attn = Zamba2Attention( - config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id - ) + self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) def forward( From 2b29338bd2c4f1becb36caa65199e5d398c2b2a0 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 07:32:27 +0000 Subject: [PATCH 48/73] minor fixes --- .../models/zamba2/modeling_zamba2.py | 23 ++++----------- .../models/zamba2/modular_zamba2.py | 27 ++++-------------- tests/models/zamba2/test_modeling_zamba2.py | 5 ++-- utils/modular_model_converter.textClipping | Bin 259 -> 0 bytes 4 files changed, 14 insertions(+), 41 deletions(-) delete mode 100644 utils/modular_model_converter.textClipping diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 5226bb3fcbf..cb3204b5b2f 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ 
b/src/transformers/models/zamba2/modeling_zamba2.py @@ -468,16 +468,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) if self.config.use_mem_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -1005,17 +996,11 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id - - def gated_act_fn(x): - x = torch.chunk(x, 2, dim=-1) - return self.act_fn(x[0]) * x[1] - - self.gated_act_fn = gated_act_fn self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) if self.config.use_shared_mlp_adapter: self.gate_up_proj_adapter_list = nn.ModuleList([]) @@ -1038,7 +1023,8 @@ def forward(self, hidden_state, layer_idx=None): layer_idx = self.layer_dic[layer_idx] gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) - hidden_state = self.gated_act_fn(gate_up_state) + gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) + hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] output = self.down_proj(hidden_state) return output @@ -1311,6 +1297,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flex_attn = True _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c86042ffe42..1d9f29d5ca2 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -35,7 +35,7 @@ is_causal_conv1d_available, is_mamba_ssm_available, ) -from ..gemma.modeling_gemma import GemmaRotaryEmbedding +from ..llama.modeling_llama import LlamaRotaryEmbedding # from ..llama.modeling_llama import apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum @@ -417,7 +417,7 @@ def reset(self): self.ssm_states.zero_() -class Zamba2RotaryEmbedding(GemmaRotaryEmbedding): +class Zamba2RotaryEmbedding(LlamaRotaryEmbedding): def __init__( self, config: Zamba2Config, @@ -514,16 +514,7 @@ def forward( value_states = value_states.view(hidden_shape).transpose(1, 2) if self.config.use_mem_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from 
computing the RoPE embeddings internally " - "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " - "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " - "removed and `position_embeddings` will be mandatory." - ) - cos, sin = self.rotary_emb(value_states, position_ids) - else: - cos, sin = position_embeddings + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -991,19 +982,11 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead. """ super().__init__(config) - self.config = config self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id - def gated_act_fn(x): - x = torch.chunk(x, 2, dim=-1) - return self.act_fn(x[0]) * x[1] - - self.gated_act_fn = gated_act_fn - del self.gate_proj del self.up_proj - del self.down_proj self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) @@ -1028,7 +1011,8 @@ def forward(self, hidden_state, layer_idx=None): layer_idx = self.layer_dic[layer_idx] gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) - hidden_state = self.gated_act_fn(gate_up_state) + gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) + hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] output = self.down_proj(hidden_state) return output @@ -1286,6 +1270,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flex_attn = True _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index a64bf543237..6fba50bdfb0 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -331,7 +331,7 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) - def test_for_casual_lm(self): + def test_for_causal_lm(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) @@ -376,12 +376,13 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip(reason="Cumbersome and redundant for Zamba2") def test_mismatched_shapes_have_properly_initialized_weights(self): r""" Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the Mamba block are initialized differently and we tested that in test_initialization """ - self.skipTest("Cumbersome and redundant for Zamba2") + pass def test_attention_outputs(self): r""" diff --git a/utils/modular_model_converter.textClipping b/utils/modular_model_converter.textClipping deleted file mode 100644 index 93a7b0661be5b2448eccefc075b9734d8e0b7aab..0000000000000000000000000000000000000000 GIT binary patch 
literal 0 HcmV?d00001 literal 259 zcmZ|EzY4-I5XbSWe=9B>L`t)dRz za9@6WQlsRF;`rvZ_PgySTqEJV-RbuFJ_}}C=MfsCL_`)dNm3W6!W?;M6v`qbaV8dw zjZ2l}k)z}C2PPkwFTNxRCb`a>Ld&WO#kej?VM$o_SCSygK|=-(6d+h&&}>m{2E4KY aN)VM${r%x+x;nVa73%Z6rZ9N*oyQYNeo5y5 From b212cb28cd2430b01580237c7fdacb48a7a516d4 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 08:50:17 +0000 Subject: [PATCH 49/73] cache_position --- src/transformers/models/zamba2/modeling_zamba2.py | 15 +++++++++++++-- src/transformers/models/zamba2/modular_zamba2.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index cb3204b5b2f..12ba9b26e23 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1297,7 +1297,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_flex_attn = True + _supports_flex_attn = False _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True @@ -1500,7 +1500,15 @@ def forward( ) if cache_position is None: - cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + past_seen_tokens = ( + past_key_values.get_seq_length(layer_idx=self.first_transformer_layer_id) + if past_key_values is not None + else 0 + ) + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1610,8 +1618,11 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] + self.first_transformer_layer_id = 0 for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": + if self.first_transformer_layer_id == 0: + self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: prefix_name = f"layers.{layer_id}." diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 1d9f29d5ca2..56282011d34 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1422,8 +1422,11 @@ def __init__(self, config: Zamba2Config): def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] + self.first_transformer_layer_id = 0 for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": + if self.first_transformer_layer_id == 0: + self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: prefix_name = f"layers.{layer_id}." 
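A minimal sketch of the `cache_position` bookkeeping introduced in the hunks above, assuming a toy cache in place of `Zamba2HybridDynamicCache` (the `ToyCache` class and `make_cache_position` helper are illustrative only, not part of the patch): positions start from the number of attention tokens already cached for the first hybrid layer, so a prefill of length 5 gets indices 0..4 and the next decode step gets a single index.

    import torch

    class ToyCache:
        """Stand-in for Zamba2HybridDynamicCache: only tracks attention tokens seen so far."""
        def __init__(self):
            self.seen_tokens = 0
        def get_seq_length(self, layer_idx=0):
            return self.seen_tokens

    def make_cache_position(inputs_embeds, past_key_values=None, first_transformer_layer_id=0):
        # Mirrors the hunk above: derive positions from the cache length of the first hybrid layer.
        past_seen_tokens = (
            past_key_values.get_seq_length(layer_idx=first_transformer_layer_id)
            if past_key_values is not None
            else 0
        )
        return torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    cache = ToyCache()
    prompt = torch.zeros(1, 5, 8)                 # prefill: batch=1, seq_len=5
    print(make_cache_position(prompt, cache))     # tensor([0, 1, 2, 3, 4])
    cache.seen_tokens = 5
    step = torch.zeros(1, 1, 8)                   # one decoded token
    print(make_cache_position(step, cache))       # tensor([5])
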
@@ -1545,7 +1548,15 @@ def forward( ) if cache_position is None: - cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + past_seen_tokens = ( + past_key_values.get_seq_length(layer_idx=self.first_transformer_layer_id) + if past_key_values is not None + else 0 + ) + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) From 1e3b51e5c08283a1a4e954506f8d07d27bb07f9a Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 09:25:38 +0000 Subject: [PATCH 50/73] removed cache_position postion_ids use_cache --- .../models/zamba2/modeling_zamba2.py | 40 +++---------------- .../models/zamba2/modular_zamba2.py | 38 ++---------------- 2 files changed, 9 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 12ba9b26e23..a5aebef9d12 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -444,11 +444,8 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -631,7 +628,6 @@ def cuda_kernels_forward( self, hidden_states: torch.Tensor, cache_params: Optional[Zamba2HybridDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): # set up dimensions for reshapes later @@ -779,7 +775,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -977,13 +973,12 @@ def forward( self, hidden_states, cache_params: Optional[Zamba2HybridDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask) - return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.torch_forward(hidden_states, cache_params, attention_mask) class Zamba2MLP(nn.Module): @@ -1056,11 +1051,8 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, 
output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -1080,8 +1072,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1092,11 +1082,8 @@ def forward( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) @@ -1126,11 +1113,9 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -1146,8 +1131,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1202,11 +1185,9 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1224,8 +1205,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. 
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1236,11 +1215,8 @@ def forward( original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) @@ -1255,11 +1231,9 @@ def forward( hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) @@ -1297,7 +1271,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_flex_attn = False + _supports_flex_attn = True _supports_sdpa = False _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True @@ -1535,11 +1509,9 @@ def forward( layer_idx, attention_mask, causal_mask, - position_ids, past_key_values, output_attentions, use_cache, - cache_position, position_embeddings, ) else: @@ -1549,11 +1521,9 @@ def forward( layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 56282011d34..eb5fb8aa062 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -490,11 +490,8 @@ def forward( hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -620,7 +617,6 @@ def cuda_kernels_forward( self, hidden_states: torch.Tensor, cache_params: Optional[Zamba2HybridDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): # set up dimensions for reshapes later @@ -768,7 +764,7 @@ def cuda_kernels_forward( return out # fmt: off - def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, attention_mask: Optional[torch.Tensor]=None): batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # Gated MLP's linear projection @@ -966,13 
+962,12 @@ def forward( self, hidden_states, cache_params: Optional[Zamba2HybridDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask) - return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + return self.torch_forward(hidden_states, cache_params, attention_mask) class Zamba2MLP(ZambaMLP): @@ -1031,11 +1026,8 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -1055,8 +1047,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1067,11 +1057,8 @@ def forward( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) @@ -1100,11 +1087,9 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -1120,8 +1105,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. 
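After this change the mixer's forward dispatches purely on kernel availability and device, threading only `cache_params` and `attention_mask` into either the fused path or the pure-PyTorch fallback. A minimal sketch of that dispatch pattern, using illustrative stand-ins (`ToyMixer`, `fast_kernels_available`) rather than the real mamba-ssm / causal-conv1d fast path:

    import torch
    import torch.nn as nn

    class ToyMixer(nn.Module):
        """Illustrates the fast-path/fallback dispatch; both branches are trivial here."""
        def __init__(self, fast_kernels_available: bool = False):
            super().__init__()
            self.in_proj = nn.Linear(8, 8)
            self.fast_kernels_available = fast_kernels_available

        def cuda_kernels_forward(self, hidden_states, cache_params=None, attention_mask=None):
            # The real module runs the fused Triton/CUDA scan here.
            return hidden_states

        def torch_forward(self, hidden_states, cache_params=None, attention_mask=None):
            # Pure-PyTorch fallback that runs on any device.
            return hidden_states

        def forward(self, hidden_states, cache_params=None, attention_mask=None):
            if self.fast_kernels_available and "cuda" in self.in_proj.weight.device.type:
                return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
            return self.torch_forward(hidden_states, cache_params, attention_mask)

    x = torch.randn(1, 4, 8)
    print(ToyMixer()(x).shape)  # torch.Size([1, 4, 8]): CPU falls back to torch_forward
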
@@ -1175,11 +1158,9 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1197,8 +1178,6 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head. @@ -1209,11 +1188,8 @@ def forward( original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) @@ -1228,11 +1204,9 @@ def forward( hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) @@ -1583,11 +1557,9 @@ def forward( layer_idx, attention_mask, causal_mask, - position_ids, past_key_values, output_attentions, use_cache, - cache_position, position_embeddings, ) else: @@ -1597,11 +1569,9 @@ def forward( layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] From 5ace701e4c7e5931d459de0f96d2c3c46291ccda Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 19:58:44 +0000 Subject: [PATCH 51/73] remove config from modular --- .../models/zamba2/modular_zamba2.py | 215 +----------------- 1 file changed, 1 insertion(+), 214 deletions(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index eb5fb8aa062..16446d616be 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -52,6 +52,7 @@ ZambaRMSNorm, eager_attention_forward, ) +from .configuration_zamba2 import Zamba2Config if is_mamba_ssm_available(): @@ -73,219 +74,6 @@ logger = logging.get_logger(__name__) -class Zamba2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a - Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the Zamba2 model. 
- - [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B) - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Zamba2Model`] - max_position_embeddings (`int`, *optional*, defaults to 4096): - The maximum sequence length that this model might ever be used with. - hidden_size (`int`, *optional*, defaults to 2560): - Dimension of the hidden representations. - num_hidden_layers (`int`, *optional*, defaults to 54): - Number of hidden layers in the model. - layers_block_type (`list`, *optional*): - List of layer types, which can be either "mamba" or "hybrid". - mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents. - mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel. - mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. - mamba_ngroups (`int`, *optional*, defaults to 1): - Number of groups for the evolution matrices of mamba 2. - time_step_min (`float`, *optional*, defaults to 0.001): - Minimum `time_step` used to bound `dt_proj.bias`. - time_step_max (`float`, *optional*, defaults to 0.1): - Maximum `time_step` used to bound `dt_proj.bias`. - time_step_floor (`float`, *optional*, defaults to 0.0001): - Minimum clamping value of the `dt_proj.bias` layer initialization. - time_step_limit (`tuple`, *optional*): - Accepted range of time step values. - n_mamba_heads (`int`, *optional*, defaults to 8): - Number of heads for the evolution matrices of mamba 2. - use_conv_bias (`bool`, *optional*, defaults to `True`): - Whether or not to use bias in the convolution layer of the mixer block. - chunk_size (`int`, *optional*, defaults to 256): - Size of the chunks that will comprise the sequence. - add_bias_linear (`bool`, *optional*, defaults to `False`): - Flag indicating whether or not to use bias in various layers - intermediate_size (`int`, *optional*, defaults to 4 * hidden_size): - Dimension of the MLP representations. - hidden_act (`str`, *optional*, defaults to `"gelu"`): - The non-linear activation function (function or string) in the MLP. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - num_mem_blocks (`int`, *optional*, defaults to 1): - Number of unshared transformer blocks. - use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`): - If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's. 
- use_shared_attention_adapter (`bool`, *optional*, defaults to `False`): - If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers. - adapter_rank (`int`, *optional*, defaults to 128): - Rank of the adapter in the shared MLP and shared attention layers. - use_mem_rope (`bool`, *optional*, defaults to `False`): - If True, includes RoPE in the shared attention layers. - rope_theta (`float`, *optional*, defaults to `10000.0`): - The base period of the RoPE embeddings. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): - Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an - integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the - logits of the last prompt token are needed for generation. For long sequences, the logits for the entire - sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint - significantly. - pad_token_id (`int`, *optional*, defaults to 0): - The id of the padding token. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the "beginning-of-sequence" token. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the "end-of-sequence" token. - use_long_context (`bool`, *optional*, defaults to `False`): - Activates the context-extended version of Zamba by modifying RoPE. 
- ```python - >>> from transformers import Zamba2Model, Zamba2Config - >>> # Initializing a Zamba2-2.7B style configuration - >>> configuration = Zamba2Config() - >>> # Initializing a model from the Zamba2-2.7B style configuration - >>> model = Zamba2Model(configuration) - >>> # Accessing the model configuration - >>> configuration = model.config - """ - - model_type = "zamba2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - max_position_embeddings=4096, - hidden_size=2560, - num_hidden_layers=54, - layers_block_type=None, - mamba_d_state=64, - mamba_d_conv=4, - mamba_expand=2, - mamba_ngroups=1, - time_step_min=0.001, - time_step_max=0.1, - time_step_floor=1e-4, - time_step_limit=None, - n_mamba_heads=8, - use_conv_bias=True, - chunk_size=256, - add_bias_linear=False, - intermediate_size=None, - hidden_act="gelu", - num_attention_heads=32, - num_key_value_heads=None, - attention_dropout=0.0, - num_mem_blocks=1, - use_shared_mlp_adapter=False, - use_shared_attention_adapter=False, - adapter_rank=128, - use_mem_rope=False, - rope_theta=10000, - initializer_range=0.02, - rms_norm_eps=1e-5, - use_cache=True, - num_logits_to_keep=1, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - use_long_context=False, - **kwargs, - ): - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - **kwargs, - ) - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - if intermediate_size is None: - self.intermediate_size = 4 * hidden_size - else: - self.intermediate_size = intermediate_size - self.hidden_act = hidden_act - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.num_mem_blocks = num_mem_blocks - self.attention_hidden_size = 2 * hidden_size - self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads - self.attention_dropout = attention_dropout - self.use_mem_rope = use_mem_rope - self.use_long_context = use_long_context - if use_mem_rope and use_long_context: - a = 8 - rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2)) - self.rope_theta = rope_theta - self.mamba_d_state = mamba_d_state - self.mamba_d_conv = mamba_d_conv - self.mamba_expand = mamba_expand - self.add_bias_linear = add_bias_linear - self.mamba_ngroups = mamba_ngroups - self.n_mamba_heads = n_mamba_heads - self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads - self.use_conv_bias = use_conv_bias - self.chunk_size = chunk_size - self.time_step_limit = time_step_limit - self.use_shared_mlp_adapter = use_shared_mlp_adapter - self.use_shared_attention_adapter = use_shared_attention_adapter - self.adapter_rank = adapter_rank - self.time_step_min = time_step_min - self.time_step_max = time_step_max - self.time_step_floor = time_step_floor - if use_long_context: - self.max_position_embeddings = 16384 - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.num_attention_heads = num_attention_heads - self.kv_channels = self.hidden_size // self.num_attention_heads - self.num_query_groups = self.num_attention_heads - # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer) - if layers_block_type is None: - self.layers_block_type = ( - ["mamba"] - + (["mamba"] * 5 + ["hybrid"]) * 7 - + ["mamba"] * 4 - + 
["hybrid"] - + ["mamba"] * 3 - + ["hybrid"] - + ["mamba"] * 2 - ) - else: - self.layers_block_type = layers_block_type - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.num_logits_to_keep = num_logits_to_keep - - class Zamba2RMSNormGated(MambaRMSNormGated): pass @@ -1623,7 +1411,6 @@ class Zamba2ForSequenceClassification(ZambaForSequenceClassification): __all__ = [ - "Zamba2Config", "Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", From 535b6319966b3ce53525cb86e8df0f6971bac464 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 20:02:58 +0000 Subject: [PATCH 52/73] removed config from modular (2) --- src/transformers/models/zamba2/modular_zamba2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 16446d616be..512fa5f6ec0 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,7 +21,6 @@ import torch.utils.checkpoint from torch import nn -from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel From 1c92266d3735ff2c8ce9cbc64dab5a7db8434e49 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 20:06:54 +0000 Subject: [PATCH 53/73] import apply_rotary_pos_emb from llama --- .../models/zamba2/modeling_zamba2.py | 36 +++++++++--------- .../models/zamba2/modular_zamba2.py | 38 +------------------ 2 files changed, 19 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a5aebef9d12..ae54c6ccee2 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -311,24 +311,6 @@ def eager_attention_forward( return attn_output, attn_weights -def layer_type_list(config: Zamba2Config): - """ - Returns list of layer ids containing hybrid layers - """ - output_list = [] - for index, type in enumerate(config.layers_block_type): - if type == "hybrid": - output_list.append(index) - return output_list - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -356,6 +338,24 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def layer_type_list(config: Zamba2Config): + """ + Returns list of layer ids containing hybrid layers + """ + output_list = [] + for index, type in enumerate(config.layers_block_type): + if type == "hybrid": + output_list.append(index) + return output_list + + class Zamba2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 512fa5f6ec0..ff8b1349029 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -34,9 +34,7 @@ is_causal_conv1d_available, is_mamba_ssm_available, ) -from ..llama.modeling_llama import LlamaRotaryEmbedding - -# from ..llama.modeling_llama import apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( ZambaAttention, @@ -103,40 +101,6 @@ def layer_type_list(config: Zamba2Config): return output_list -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache From 99bde9383fdcaa1abb156be0fc2ea440c6d8c59f Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 20:11:40 +0000 Subject: [PATCH 54/73] fixed rope_kwargs --- src/transformers/models/zamba2/modeling_zamba2.py | 6 ++++-- src/transformers/models/zamba2/modular_zamba2.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index ae54c6ccee2..75d83e3f92e 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -215,7 +215,6 @@ def __init__( device=None, ): super().__init__() - self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim} # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) @@ -226,7 +225,10 @@ def __init__( self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) + # we cannot use the config here to parameterize because of a factor 2 for the head_dim + inv_freq, self.attention_scaling = self.rope_init_fn( + device=device, base=config.rope_theta, dim=config.attention_head_dim + ) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index ff8b1349029..3daab7dea8a 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -174,9 +174,11 @@ def __init__( config: Zamba2Config, device=None, ): - self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim} super().__init__(config, device) - inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs) + # we cannot use the config here to parameterize because of a factor 2 for the head_dim + inv_freq, self.attention_scaling = self.rope_init_fn( + device=device, base=config.rope_theta, dim=config.attention_head_dim + ) class Zamba2Attention(ZambaAttention): From baf2ed3f86a0d841e493ab023c210b7124b0f4de Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 16 Jan 2025 20:47:17 +0000 Subject: [PATCH 55/73] Instantiate cache in Zamba2Model --- src/transformers/models/zamba2/modeling_zamba2.py | 6 ++---- src/transformers/models/zamba2/modular_zamba2.py | 6 ++---- tests/models/zamba2/test_modeling_zamba2.py | 9 +++++++++ 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 75d83e3f92e..782607b1294 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1470,10 +1470,8 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and 
past_key_values is None: - logger.warning_once( - "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) if cache_position is None: # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 3daab7dea8a..8fb25cf23d7 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -1269,10 +1269,8 @@ def forward( # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer if use_cache and past_key_values is None: - logger.warning_once( - "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was " - "provided, so no cache will be returned." - ) + batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) if cache_position is None: # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 6fba50bdfb0..93792ac02a1 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -324,6 +324,15 @@ def setUp(self): def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): pass + @unittest.skip("Zamba2 has a hybrid cache") + def test_past_key_values_format(self): + r""" + Zamba2's cache shape depends on whether a given layer is mamba or attention. + For mamba layers, the KV cache has shape is empty and has shape [batch_size, 0]. + The shape checks of this test assume instead that every layer has an attention cache, so we skip it. + """ + pass + def test_config(self): self.config_tester.run_common_tests() From 9afb57ec0dd2e877c7977a188caaaa0b64dc3115 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 17 Jan 2025 04:55:12 +0000 Subject: [PATCH 56/73] fix cache --- src/transformers/models/zamba2/modeling_zamba2.py | 3 +-- src/transformers/models/zamba2/modular_zamba2.py | 9 ++++++++- tests/models/zamba2/test_modeling_zamba2.py | 8 ++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 782607b1294..a8d7575a69f 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -180,7 +180,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" # take any layer that contains cache and not empty tensor layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx - if len(self.key_cache) <= layer_idx: + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: return 0 return self.key_cache[layer_idx].shape[-2] @@ -1474,7 +1474,6 @@ def forward( past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) if cache_position is None: - # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) past_seen_tokens = ( past_key_values.get_seq_length(layer_idx=self.first_transformer_layer_id) if past_key_values is not None diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 8fb25cf23d7..0e5e48cc6e1 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -167,6 +167,14 @@ def reset(self): self.conv_states.zero_() self.ssm_states.zero_() + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0: + return 0 + return self.key_cache[layer_idx].shape[-2] + class Zamba2RotaryEmbedding(LlamaRotaryEmbedding): def __init__( @@ -1273,7 +1281,6 @@ def forward( past_key_values = Zamba2HybridDynamicCache(self.config, batch_size, dtype=self.dtype, device=self.device) if cache_position is None: - # cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) past_seen_tokens = ( past_key_values.get_seq_length(layer_idx=self.first_transformer_layer_id) if past_key_values is not None diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 93792ac02a1..c6fd1a9ce5d 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -333,6 +333,10 @@ def test_past_key_values_format(self): """ pass + @unittest.skip(reason="A large mamba2 would be necessary (and costly) for that") + def test_multi_gpu_data_parallel_forward(self): + pass + def test_config(self): self.config_tester.run_common_tests() @@ -512,7 +516,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): @require_torch_gpu @require_bitsandbytes @pytest.mark.flash_attn_test - @slow + # @slow def test_flash_attn_2_fp32_ln(self): r""" Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support @@ -550,7 +554,7 @@ def test_flash_attn_2_fp32_ln(self): @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test - @slow + # @slow def test_flash_attn_2_inference_equivalence_right_padding(self): r""" Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't support From d1687f910f89702e337c0347c812a87e88e8acef Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 17 Jan 2025 04:59:49 +0000 Subject: [PATCH 57/73] fix @slow decorator --- src/transformers/models/zamba2/modeling_zamba2.py | 14 +++++++------- tests/models/zamba2/test_modeling_zamba2.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git 
a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index a8d7575a69f..814eef56a48 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -313,6 +313,13 @@ def eager_attention_forward( return attn_output, attn_weights +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -340,13 +347,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def layer_type_list(config: Zamba2Config): """ Returns list of layer ids containing hybrid layers diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index c6fd1a9ce5d..bd3df9f6e59 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -516,7 +516,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature): @require_torch_gpu @require_bitsandbytes @pytest.mark.flash_attn_test - # @slow + @slow def test_flash_attn_2_fp32_ln(self): r""" Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support @@ -554,7 +554,7 @@ def test_flash_attn_2_fp32_ln(self): @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test - # @slow + @slow def test_flash_attn_2_inference_equivalence_right_padding(self): r""" Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't support From 903f6dc6afe3453229f876b01e1b0528b0853f09 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 21 Jan 2025 08:45:43 +0000 Subject: [PATCH 58/73] small fix in modular file --- src/transformers/models/zamba2/modeling_zamba2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 814eef56a48..4e1a1a55c89 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -55,7 +55,8 @@ logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "Zamba2Config" + +_CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B" class Zamba2RMSNormGated(torch.nn.Module): From 14396d748258f4f5d0a1080350bbf2d9d1c3c91b Mon Sep 17 00:00:00 2001 From: pglorio <85982602+pglorio@users.noreply.github.com> Date: Wed, 22 Jan 2025 22:41:10 -0800 Subject: [PATCH 59/73] Update docs/source/en/model_doc/zamba2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/model_doc/zamba2.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/zamba2.md b/docs/source/en/model_doc/zamba2.md index 75333555d45..4aa2a437ead 100644 --- a/docs/source/en/model_doc/zamba2.md +++ b/docs/source/en/model_doc/zamba2.md @@ -31,10 +31,9 @@ Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space m ### Presequities -Zamba2 requires you use `transformers` version 4.46.0 or higher: +Zamba2 requires you use `transformers` version 4.48.0 or higher: ```bash -pip install 
transformers>=4.46.0 -``` +pip install transformers>=4.48.0 ## Inference From 02f58079962e1924e01844d0c9581241695d5a58 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 23 Jan 2025 07:31:07 +0000 Subject: [PATCH 60/73] several minor fixes --- .../models/zamba/modeling_zamba.py | 3 - .../models/zamba2/configuration_zamba2.py | 1 + .../models/zamba2/modeling_zamba2.py | 66 ++++------- .../models/zamba2/modular_zamba2.py | 109 +++--------------- 4 files changed, 39 insertions(+), 140 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 761c799bdcd..9f8830c81f0 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -271,7 +271,6 @@ def forward( layer_idx: int, attention_mask: Optional[torch.Tensor], past_key_value: Optional[ZambaHybridDynamicCache] = None, - cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -624,7 +623,6 @@ def forward( past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -658,7 +656,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, **kwargs, ) # feed-forward (MLP) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 01f1a8f7776..6e66bb53dea 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -234,6 +234,7 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.num_logits_to_keep = num_logits_to_keep + self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"] __all__ = ["Zamba2Config"] diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 4e1a1a55c89..92d147a198b 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -125,23 +125,19 @@ def __init__( self._modules = {} self._parameters = {} self._buffers = {} - self.conv_states = { - i: torch.zeros( + self.conv_states = {} + self.ssm_states = {} + for i in range(config.num_hidden_layers): + self.conv_states[i] = torch.zeros( batch_size, self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, self.conv_kernel_size, device=device, dtype=dtype, ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( + self.ssm_states[i] = torch.zeros( batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype ) - for i in range(config.num_hidden_layers) - } - for i in range(config.num_hidden_layers): if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] @@ -348,17 +344,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -def layer_type_list(config: Zamba2Config): - 
""" - Returns list of layer ids containing hybrid layers - """ - output_list = [] - for index, type in enumerate(config.layers_block_type): - if type == "hybrid": - output_list.append(index) - return output_list - - class Zamba2Attention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -410,7 +395,7 @@ def __init__( self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.num_fwd_mem_blocks = num_fwd_mem_blocks - self.layer_block_map = layer_type_list(config) + self.layer_block_map = config.hybrid_layer_ids self.block_id = block_id if config.use_shared_attention_adapter: @@ -994,11 +979,12 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) - self.act_fn = ACT2FN[config.hidden_act] self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + self.act_fn = ACT2FN[config.hidden_act] if self.config.use_shared_mlp_adapter: self.gate_up_proj_adapter_list = nn.ModuleList([]) @@ -1012,7 +998,7 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) - layer_block_map = layer_type_list(config) + layer_block_map = config.hybrid_layer_ids self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): @@ -1027,22 +1013,11 @@ def forward(self, hidden_state, layer_idx=None): return output -def count_mem_blocks_in_config(config: Zamba2Config): - """ - Count number of shared blocks - """ - num_gs = 0 - for val in config.layers_block_type: - if val == "hybrid": - num_gs += 1 - return num_gs - - class Zamba2AttentionDecoderLayer(nn.Module): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): super().__init__() self.block_id = block_id - num_gs = count_mem_blocks_in_config(config) + num_gs = len(config.hybrid_layer_ids) self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) @@ -1054,9 +1029,10 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, - position_embeddings: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1066,8 +1042,10 @@ def forward( This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). 
The concatenated tensor is then used as input of the pre-attention RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). + layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba2's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. + position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings. past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -1075,9 +1053,8 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. """ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) @@ -1085,14 +1062,15 @@ def forward( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, + position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, - position_embeddings=position_embeddings, + use_cache=use_cache, **kwargs, ) - + # feed-forward (MLP) hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states, layer_idx) + hidden_states = self.feed_forward(hidden_states) outputs = (hidden_states,) @@ -1594,7 +1572,7 @@ def get_layers(self, blocks, linear_layers, mamba_layers): if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id block = next(blocks) - if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: + if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: prefix_name = f"layers.{layer_id}." 
tied_keys = [ "shared_transformer.self_attn.q_proj.weight", diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 0e5e48cc6e1..c56bf9d552e 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -21,6 +21,7 @@ import torch.utils.checkpoint from torch import nn +from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -44,7 +45,6 @@ ZambaHybridDynamicCache, ZambaHybridLayer, ZambaMambaDecoderLayer, - ZambaMLP, ZambaModel, ZambaRMSNorm, eager_attention_forward, @@ -79,28 +79,6 @@ class Zamba2RMSNorm(ZambaRMSNorm): pass -def count_mem_blocks_in_config(config: Zamba2Config): - """ - Count number of shared blocks - """ - num_gs = 0 - for val in config.layers_block_type: - if val == "hybrid": - num_gs += 1 - return num_gs - - -def layer_type_list(config: Zamba2Config): - """ - Returns list of layer ids containing hybrid layers - """ - output_list = [] - for index, type in enumerate(config.layers_block_type): - if type == "hybrid": - output_list.append(index) - return output_list - - class Zamba2HybridDynamicCache(ZambaHybridDynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -129,23 +107,19 @@ def __init__( self._modules = {} self._parameters = {} self._buffers = {} - self.conv_states = { - i: torch.zeros( + self.conv_states = {} + self.ssm_states = {} + for i in range(config.num_hidden_layers): + self.conv_states[i] = torch.zeros( batch_size, self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state, self.conv_kernel_size, device=device, dtype=dtype, ) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros( + self.ssm_states[i] = torch.zeros( batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype ) - for i in range(config.num_hidden_layers) - } - for i in range(config.num_hidden_layers): if self.layers_block_type[i] == "hybrid": self.transformer_layers.append(i) self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] @@ -214,7 +188,7 @@ def __init__( ): super().__init__(config, layer_idx) self.num_fwd_mem_blocks = num_fwd_mem_blocks - self.layer_block_map = layer_type_list(config) + self.layer_block_map = config.hybrid_layer_ids self.block_id = block_id if config.use_shared_attention_adapter: @@ -731,20 +705,22 @@ def forward( return self.torch_forward(hidden_states, cache_params, attention_mask) -class Zamba2MLP(ZambaMLP): +class Zamba2MLP(nn.Module): def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None): """ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead. 
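Because the gate and up projections are fused into a single `gate_up_proj` here, the forward pass splits the projected state back into two halves before the gated activation, and the optional per-block adapter adds a low-rank correction on top. A rough sketch of that flow, with the exact chunk/activation ordering inferred from the layers defined above rather than quoted from the final file:

    # Rough sketch of Zamba2MLP.forward with the fused projection and optional adapter.
    def forward(self, hidden_state, layer_idx=None):
        gate_up_state = self.gate_up_proj(hidden_state)
        if self.config.use_shared_mlp_adapter:
            # map the global layer index onto this shared block's adapter slot
            adapter = self.gate_up_proj_adapter_list[self.layer_dic[layer_idx]]
            gate_up_state = gate_up_state + adapter(hidden_state)
        gate, up = torch.chunk(gate_up_state, 2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)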
""" - super().__init__(config) + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size self.num_fwd_mem_blocks = num_fwd_mem_blocks self.block_id = block_id - del self.gate_proj - del self.up_proj self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) + self.act_fn = ACT2FN[config.hidden_act] if self.config.use_shared_mlp_adapter: self.gate_up_proj_adapter_list = nn.ModuleList([]) @@ -758,7 +734,7 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int gate_up_proj_adapter = nn.Identity() self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) - layer_block_map = layer_type_list(config) + layer_block_map = config.hybrid_layer_ids self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): @@ -776,64 +752,11 @@ def forward(self, hidden_state, layer_idx=None): class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None): self.block_id = block_id - num_gs = count_mem_blocks_in_config(config) + num_gs = len(config.hybrid_layer_ids) super().__init__(config, layer_idx) self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) - def forward( - self, - hidden_states: torch.Tensor, - original_hidden_states: torch.Tensor, - layer_idx: int, - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: Optional[bool] = False, - position_embeddings: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` - original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. - This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The - concatenated tensor is then used as input of the pre-attention RMSNorm - (see fig. 2 in https://arxiv.org/pdf/2405.16712). - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. 
- """ - hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) - hidden_states = self.input_layernorm(hidden_states) - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - layer_idx=layer_idx, - attention_mask=attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - position_embeddings=position_embeddings, - **kwargs, - ) - - hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states, layer_idx) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer): def __init__(self, config: Zamba2Config, layer_idx: int): @@ -1163,7 +1086,7 @@ def get_layers(self, blocks, linear_layers, mamba_layers): if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id block = next(blocks) - if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1: + if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: prefix_name = f"layers.{layer_id}." tied_keys = [ "shared_transformer.self_attn.q_proj.weight", From bfb026750ec1277407ceda71fbe9fc873cd4e87f Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 23 Jan 2025 08:06:24 +0000 Subject: [PATCH 61/73] inherit mamba2decoder fwd and drop position_ids in mamba --- .../models/zamba/modeling_zamba.py | 10 +-- .../models/zamba2/modeling_zamba2.py | 13 ++-- .../models/zamba2/modular_zamba2.py | 62 ------------------- 3 files changed, 5 insertions(+), 80 deletions(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 9f8830c81f0..e62a48f541b 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -619,7 +619,6 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -635,7 +634,6 @@ def forward( layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings. past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under @@ -652,7 +650,6 @@ def forward( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -684,12 +681,12 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -752,7 +749,6 @@ def forward( layer_idx: int = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -782,7 +778,6 @@ def forward( original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -800,7 +795,6 @@ def forward( hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1104,7 +1098,6 @@ def forward( layer_idx, attention_mask, causal_mask, - position_ids, past_key_values, output_attentions, use_cache, @@ -1117,7 +1110,6 @@ def forward( layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 92d147a198b..57c7f9ba8d7 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1029,7 +1029,6 @@ def forward( original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -1045,7 +1044,6 @@ def forward( layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba2's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - position_ids (`torch.LongTensor`, *optional*): token positions of shape `(batch, seq_len)`. Used for positional encodings. past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under @@ -1062,7 +1060,6 @@ def forward( hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, @@ -1097,8 +1094,9 @@ def forward( past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - position_embeddings: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, transformer_hidden_states: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -1112,11 +1110,8 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. """ residual = hidden_states diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index c56bf9d552e..23a5221dae1 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -764,68 +764,6 @@ def __init__(self, config: Zamba2Config, layer_idx: int): self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx) self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( - self, - hidden_states: torch.Tensor, - original_hidden_states: Optional[torch.Tensor] = None, - layer_idx: int = None, - attention_mask: Optional[torch.Tensor] = None, - causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - position_embeddings: Optional[torch.LongTensor] = None, - transformer_hidden_states: Optional[torch.Tensor] = None, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`. - """ - - residual = hidden_states - - # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712). - # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712). - hidden_states = ( - hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states - ) - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.mamba( - hidden_states=hidden_states, - cache_params=past_key_value, - attention_mask=attention_mask, - ) - - self_attn_weights = None - - # residual connection after mamba - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs - class Zamba2HybridLayer(ZambaHybridLayer): def __init__( From b2229430a3cae1e66f979464ddd42eda056728a5 Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 23 Jan 2025 08:17:29 +0000 Subject: [PATCH 62/73] removed docstrings from modular --- .../models/zamba2/modeling_zamba2.py | 38 +++--- .../models/zamba2/modular_zamba2.py | 112 ------------------ 2 files changed, 17 insertions(+), 133 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 57c7f9ba8d7..d090741850b 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1219,27 +1219,6 @@ def forward( return layer_outputs -ZAMBA2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Zamba2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", - ZAMBA2_START_DOCSTRING, -) class Zamba2PreTrainedModel(PreTrainedModel): config_class = Zamba2Config base_model_prefix = "model" @@ -1279,6 +1258,23 @@ def _init_weights(self, module): module.dt_bias._no_reinit = True +ZAMBA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
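After the deletion above, the modular decoder layer no longer carries its own forward: it only swaps in the Zamba2 mixer and norm, and the forward implementation is inherited from `ZambaMambaDecoderLayer`, which the modular converter then materializes in `modeling_zamba2.py`. What remains of the class is roughly:

    # What is left of the modular class once the duplicated forward is removed;
    # forward() now comes from ZambaMambaDecoderLayer.
    class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer):
        def __init__(self, config: Zamba2Config, layer_idx: int):
            super().__init__(config, layer_idx)
            self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
            self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)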
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Zamba2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + ZAMBA2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 23a5221dae1..684172352fc 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -27,8 +27,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, logging, ) from ...utils.import_utils import ( @@ -838,27 +836,6 @@ def forward( return layer_outputs -ZAMBA2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Zamba2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", - ZAMBA2_START_DOCSTRING, -) class Zamba2PreTrainedModel(PreTrainedModel): config_class = Zamba2Config base_model_prefix = "model" @@ -898,79 +875,6 @@ def _init_weights(self, module): module.dt_bias._no_reinit = True -ZAMBA2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). 
- - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Zamba2HybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - A Zamba2HybridDynamicCache object containing pre-computed hidden-states (keys and values in the - self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. - Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and - `(batch_size, d_inner, d_state)` respectively. - See the `Zamba2HybridDynamicCache` class for more details. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.", - ZAMBA2_START_DOCSTRING, -) class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): """ Model consisting of *config.num_hidden_layers* layers. 
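The inputs docstring removed above spells out the layout of the hybrid cache: attention layers keep key/value tensors that grow along the sequence axis, while mamba layers keep fixed-size convolution and SSM states. A minimal self-contained sketch of buffers with those shapes, using made-up sizes and a plain dict rather than the actual Zamba2HybridDynamicCache class, could look like:

import torch

# Illustrative stand-in only; all sizes below are placeholder values.
batch_size, num_heads, head_dim = 2, 4, 32
d_inner, d_conv, d_state = 256, 4, 16

cache = {
    # attention layers: keys/values grow along the sequence dimension
    "key_cache": torch.zeros(batch_size, num_heads, 0, head_dim),
    "value_cache": torch.zeros(batch_size, num_heads, 0, head_dim),
    # mamba layers: fixed-size rolling convolution state and SSM state
    "conv_state": torch.zeros(batch_size, d_inner, d_conv),
    "ssm_state": torch.zeros(batch_size, d_inner, d_state),
}

new_k = torch.randn(batch_size, num_heads, 1, head_dim)  # one decoding step
cache["key_cache"] = torch.cat([cache["key_cache"], new_k], dim=2)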
@@ -1096,7 +1000,6 @@ def get_layers(self, blocks, linear_layers, mamba_layers): layers.append(next(mamba_layers)) return layers - @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, @@ -1222,21 +1125,6 @@ class Zamba2ForCausalLM(ZambaForCausalLM): pass -@add_start_docstrings( - """ - The Zamba2 Model with a sequence classification head on top (linear layer). - - [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - ZAMBA2_START_DOCSTRING, -) class Zamba2ForSequenceClassification(ZambaForSequenceClassification): pass From 929ee67bc3665373887ac1e41455a182b44e958b Mon Sep 17 00:00:00 2001 From: pglorio Date: Thu, 23 Jan 2025 08:50:26 +0000 Subject: [PATCH 63/73] reinstate zamba2 attention decoder fwd --- .../models/zamba2/modeling_zamba2.py | 14 ++--- .../models/zamba2/modular_zamba2.py | 53 +++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index d090741850b..2148d0ce910 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1031,7 +1031,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, + position_embeddings: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1041,7 +1041,6 @@ def forward( This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The concatenated tensor is then used as input of the pre-attention RMSNorm (see fig. 2 in https://arxiv.org/pdf/2405.16712). - layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba2's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states @@ -1051,8 +1050,9 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
""" hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) hidden_states = self.input_layernorm(hidden_states) @@ -1062,12 +1062,12 @@ def forward( attention_mask=attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, - use_cache=use_cache, + position_embeddings=position_embeddings, **kwargs, ) - # feed-forward (MLP) + hidden_states = self.pre_ff_layernorm(hidden_states) - hidden_states = self.feed_forward(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) outputs = (hidden_states,) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 684172352fc..8a4e379b3ce 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -755,6 +755,59 @@ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Option self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) + def forward( + self, + hidden_states: torch.Tensor, + original_hidden_states: torch.Tensor, + layer_idx: int, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Zamba2HybridDynamicCache] = None, + output_attentions: Optional[bool] = False, + position_embeddings: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)` + original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`. + This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The + concatenated tensor is then used as input of the pre-attention RMSNorm + (see fig. 2 in https://arxiv.org/pdf/2405.16712). + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ """ + hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1) + hidden_states = self.input_layernorm(hidden_states) + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + layer_idx=layer_idx, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.pre_ff_layernorm(hidden_states) + hidden_states = self.feed_forward(hidden_states, layer_idx) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer): def __init__(self, config: Zamba2Config, layer_idx: int): From 9007a522b1f831df6d516a281c0d3fdd20a118f5 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 24 Jan 2025 01:05:04 +0000 Subject: [PATCH 64/73] use regex for tied keys --- .../models/zamba2/modeling_zamba2.py | 95 +++++++++---------- .../models/zamba2/modular_zamba2.py | 95 +++++++++---------- 2 files changed, 86 insertions(+), 104 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2148d0ce910..7c418beae9f 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import re from itertools import cycle from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -1558,81 +1559,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] self.first_transformer_layer_id = 0 + for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id + block = next(blocks) + + # If there are multiple shared blocks, we gather tied keys if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transformer.self_attn.q_proj.weight", - "shared_transformer.self_attn.k_proj.weight", - "shared_transformer.self_attn.v_proj.weight", - "shared_transformer.self_attn.o_proj.weight", - "shared_transformer.feed_forward.gate_up_proj.weight", - "shared_transformer.feed_forward.down_proj.weight", - "shared_transformer.input_layernorm.weight", - "shared_transformer.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + # We will incorporate the 'layers.{layer_id}.' prefix into the patterns. + prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." + + # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms) + # combined into one pattern. You can separate these into multiple regex + # entries if you prefer finer granularity. + main_keys_pattern = re.compile( + prefix_pattern + + r"(?:" + + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" + + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" + + r"(?:input_layernorm|pre_ff_layernorm)\.weight" + + r")$" + ) + self._tied_weights_keys.append(main_keys_pattern) + + # 2) If using shared MLP adapter layers, create regex patterns for those. 
if self.config.use_shared_mlp_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: + # Only add keys for the relevant adapter_id / block_id if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." + # gate_up_proj_adapter_list.X.[0|1].weight + # Instead of storing multiple strings, store a single combined regex + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." - + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + + # 3) If using shared Attention adapter layers, create regex patterns for those. if self.config.use_shared_attention_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." + # linear_q_adapter_list.X.[0|1].weight + # linear_k_adapter_list.X.[0|1].weight + # linear_v_adapter_list.X.[0|1].weight + # We'll combine them, but if you want separate patterns, split accordingly. + attn_adapter_pattern = re.compile( + r"^shared_transformer\.self_attn\." + + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + + # Construct the actual layer layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) + return layers diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 8a4e379b3ce..64d8cd8087c 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
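The compiled patterns introduced in the hunks above are meant to match the fully qualified parameter names of the shared transformer block instead of listing each tied key as a string. A small self-contained check, using a hypothetical layer_id of 6, illustrates what the main pattern accepts and rejects:

import re

layer_id = 6  # made-up example; the real patterns are built per hybrid layer inside get_layers()
main_keys_pattern = re.compile(
    rf"^layers\.{layer_id}\.shared_transformer\."
    r"(?:"
    r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
    r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
    r"(?:input_layernorm|pre_ff_layernorm)\.weight"
    r")$"
)

assert main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.weight")
assert main_keys_pattern.match("layers.6.shared_transformer.feed_forward.down_proj.weight")
assert main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.bias") is None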
import math +import re from itertools import cycle from typing import Callable, Optional, Tuple, Union @@ -976,81 +977,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] self.first_transformer_layer_id = 0 + for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id + block = next(blocks) + + # If there are multiple shared blocks, we gather tied keys if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transformer.self_attn.q_proj.weight", - "shared_transformer.self_attn.k_proj.weight", - "shared_transformer.self_attn.v_proj.weight", - "shared_transformer.self_attn.o_proj.weight", - "shared_transformer.feed_forward.gate_up_proj.weight", - "shared_transformer.feed_forward.down_proj.weight", - "shared_transformer.input_layernorm.weight", - "shared_transformer.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + # We will incorporate the 'layers.{layer_id}.' prefix into the patterns. + prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." + + # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms) + # combined into one pattern. You can separate these into multiple regex + # entries if you prefer finer granularity. + main_keys_pattern = re.compile( + prefix_pattern + + r"(?:" + + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" + + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" + + r"(?:input_layernorm|pre_ff_layernorm)\.weight" + + r")$" + ) + self._tied_weights_keys.append(main_keys_pattern) + + # 2) If using shared MLP adapter layers, create regex patterns for those. if self.config.use_shared_mlp_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: + # Only add keys for the relevant adapter_id / block_id if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." + # gate_up_proj_adapter_list.X.[0|1].weight + # Instead of storing multiple strings, store a single combined regex + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." - + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + + # 3) If using shared Attention adapter layers, create regex patterns for those. if self.config.use_shared_attention_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." + # linear_q_adapter_list.X.[0|1].weight + # linear_k_adapter_list.X.[0|1].weight + # linear_v_adapter_list.X.[0|1].weight + # We'll combine them, but if you want separate patterns, split accordingly. + attn_adapter_pattern = re.compile( + r"^shared_transformer\.self_attn\." 
+ + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] + + # Construct the actual layer layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) + return layers def forward( From f701dbd471326039accbd6b3801c3be6d518ed84 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 24 Jan 2025 01:16:12 +0000 Subject: [PATCH 65/73] Revert "use regex for tied keys" This reverts commit 9007a522b1f831df6d516a281c0d3fdd20a118f5. --- .../models/zamba2/modeling_zamba2.py | 95 ++++++++++--------- .../models/zamba2/modular_zamba2.py | 95 ++++++++++--------- 2 files changed, 104 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 7c418beae9f..2148d0ce910 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -20,7 +20,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -import re from itertools import cycle from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -1559,71 +1558,81 @@ def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] self.first_transformer_layer_id = 0 - for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id - block = next(blocks) - - # If there are multiple shared blocks, we gather tied keys if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - # We will incorporate the 'layers.{layer_id}.' prefix into the patterns. - prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." - - # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms) - # combined into one pattern. You can separate these into multiple regex - # entries if you prefer finer granularity. - main_keys_pattern = re.compile( - prefix_pattern - + r"(?:" - + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" - + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" - + r"(?:input_layernorm|pre_ff_layernorm)\.weight" - + r")$" - ) - self._tied_weights_keys.append(main_keys_pattern) - - # 2) If using shared MLP adapter layers, create regex patterns for those. + prefix_name = f"layers.{layer_id}." 
+ tied_keys = [ + "shared_transformer.self_attn.q_proj.weight", + "shared_transformer.self_attn.k_proj.weight", + "shared_transformer.self_attn.v_proj.weight", + "shared_transformer.self_attn.o_proj.weight", + "shared_transformer.feed_forward.gate_up_proj.weight", + "shared_transformer.feed_forward.down_proj.weight", + "shared_transformer.input_layernorm.weight", + "shared_transformer.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] if self.config.use_shared_mlp_adapter: + tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: - # Only add keys for the relevant adapter_id / block_id if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - # gate_up_proj_adapter_list.X.[0|1].weight - # Instead of storing multiple strings, store a single combined regex - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + str(adapter_id) - + r"\.(?:0|1)\.weight$" + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".1.weight" ) - self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - - # 3) If using shared Attention adapter layers, create regex patterns for those. + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] if self.config.use_shared_attention_adapter: + tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - # linear_q_adapter_list.X.[0|1].weight - # linear_k_adapter_list.X.[0|1].weight - # linear_v_adapter_list.X.[0|1].weight - # We'll combine them, but if you want separate patterns, split accordingly. - attn_adapter_pattern = re.compile( - r"^shared_transformer\.self_attn\." - + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) - + r"\.(?:0|1)\.weight$" + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".1.weight" ) - self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - - # Construct the actual layer + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) - return layers diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 64d8cd8087c..8a4e379b3ce 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
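Aside from the tied-keys bookkeeping, get_layers consumes the blocks iterator with next(...) once per hybrid layer, so with num_mem_blocks shared blocks the assignment is round-robin. A toy sketch of that interleaving, with placeholder strings standing in for the shared transformer modules and a made-up layer layout:

from itertools import cycle

# Hypothetical layout of pure-mamba vs hybrid (shared-attention) layers.
layers_block_type = ["mamba", "mamba", "hybrid", "mamba", "hybrid", "mamba", "hybrid"]
num_mem_blocks = 2

# get_layers() receives an iterator like this and calls next(blocks) per hybrid layer,
# so the shared blocks are reused in round-robin order.
blocks = cycle([f"shared_block_{i}" for i in range(num_mem_blocks)])

assignment = [
    (layer_id, next(blocks))
    for layer_id, layer_type in enumerate(layers_block_type)
    if layer_type == "hybrid"
]
print(assignment)  # [(2, 'shared_block_0'), (4, 'shared_block_1'), (6, 'shared_block_0')]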
import math -import re from itertools import cycle from typing import Callable, Optional, Tuple, Union @@ -977,71 +976,81 @@ def get_layers(self, blocks, linear_layers, mamba_layers): layers = [] self._tied_weights_keys = [] self.first_transformer_layer_id = 0 - for layer_id, layer_type in enumerate(self.layers_block_type): if layer_type == "hybrid": if self.first_transformer_layer_id == 0: self.first_transformer_layer_id = layer_id - block = next(blocks) - - # If there are multiple shared blocks, we gather tied keys if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - # We will incorporate the 'layers.{layer_id}.' prefix into the patterns. - prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." - - # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms) - # combined into one pattern. You can separate these into multiple regex - # entries if you prefer finer granularity. - main_keys_pattern = re.compile( - prefix_pattern - + r"(?:" - + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" - + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" - + r"(?:input_layernorm|pre_ff_layernorm)\.weight" - + r")$" - ) - self._tied_weights_keys.append(main_keys_pattern) - - # 2) If using shared MLP adapter layers, create regex patterns for those. + prefix_name = f"layers.{layer_id}." + tied_keys = [ + "shared_transformer.self_attn.q_proj.weight", + "shared_transformer.self_attn.k_proj.weight", + "shared_transformer.self_attn.v_proj.weight", + "shared_transformer.self_attn.o_proj.weight", + "shared_transformer.feed_forward.gate_up_proj.weight", + "shared_transformer.feed_forward.down_proj.weight", + "shared_transformer.input_layernorm.weight", + "shared_transformer.pre_ff_layernorm.weight", + ] + self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] if self.config.use_shared_mlp_adapter: + tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: - # Only add keys for the relevant adapter_id / block_id if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - # gate_up_proj_adapter_list.X.[0|1].weight - # Instead of storing multiple strings, store a single combined regex - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + str(adapter_id) - + r"\.(?:0|1)\.weight$" + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.feed_forward.gate_up_proj_adapter_list." + + str(adapter_id) + + ".1.weight" ) - self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - - # 3) If using shared Attention adapter layers, create regex patterns for those. + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] if self.config.use_shared_attention_adapter: + tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - # linear_q_adapter_list.X.[0|1].weight - # linear_k_adapter_list.X.[0|1].weight - # linear_v_adapter_list.X.[0|1].weight - # We'll combine them, but if you want separate patterns, split accordingly. - attn_adapter_pattern = re.compile( - r"^shared_transformer\.self_attn\." - + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." 
+ + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + str(adapter_id) - + r"\.(?:0|1)\.weight$" + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".0.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_q_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_k_adapter_list." + + str(adapter_id) + + ".1.weight" + ) + tied_keys_adapter.append( + "shared_transformer.self_attn.linear_v_adapter_list." + + str(adapter_id) + + ".1.weight" ) - self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - - # Construct the actual layer + self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) - return layers def forward( From 87b938b488a15df501b9fe7b522c78e9ba1d26eb Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 24 Jan 2025 01:29:29 +0000 Subject: [PATCH 66/73] use regex for tied keys --- .../models/zamba2/modeling_zamba2.py | 73 ++++++------------- .../models/zamba2/modular_zamba2.py | 73 ++++++------------- 2 files changed, 42 insertions(+), 104 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 2148d0ce910..4961802ef22 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -20,6 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import re from itertools import cycle from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -1564,72 +1565,40 @@ def get_layers(self, blocks, linear_layers, mamba_layers): self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transformer.self_attn.q_proj.weight", - "shared_transformer.self_attn.k_proj.weight", - "shared_transformer.self_attn.v_proj.weight", - "shared_transformer.self_attn.o_proj.weight", - "shared_transformer.feed_forward.gate_up_proj.weight", - "shared_transformer.feed_forward.down_proj.weight", - "shared_transformer.input_layernorm.weight", - "shared_transformer.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." + main_keys_pattern = re.compile( + prefix_pattern + + r"(?:" + + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" + + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" + + r"(?:input_layernorm|pre_ff_layernorm)\.weight" + + r")$" + ) + self._tied_weights_keys.append(main_keys_pattern) + if self.config.use_shared_mlp_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." 
+ str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] if self.config.use_shared_attention_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." + attn_adapter_pattern = re.compile( + r"^shared_transformer\.self_attn\." + + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 8a4e379b3ce..d52f103f9ad 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math +import re from itertools import cycle from typing import Callable, Optional, Tuple, Union @@ -982,72 +983,40 @@ def get_layers(self, blocks, linear_layers, mamba_layers): self.first_transformer_layer_id = layer_id block = next(blocks) if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1: - prefix_name = f"layers.{layer_id}." - tied_keys = [ - "shared_transformer.self_attn.q_proj.weight", - "shared_transformer.self_attn.k_proj.weight", - "shared_transformer.self_attn.v_proj.weight", - "shared_transformer.self_attn.o_proj.weight", - "shared_transformer.feed_forward.gate_up_proj.weight", - "shared_transformer.feed_forward.down_proj.weight", - "shared_transformer.input_layernorm.weight", - "shared_transformer.pre_ff_layernorm.weight", - ] - self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]] + prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\." + main_keys_pattern = re.compile( + prefix_pattern + + r"(?:" + + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|" + + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|" + + r"(?:input_layernorm|pre_ff_layernorm)\.weight" + + r")$" + ) + self._tied_weights_keys.append(main_keys_pattern) + if self.config.use_shared_mlp_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." 
- + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.feed_forward.gate_up_proj_adapter_list." + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] if self.config.use_shared_attention_adapter: - tied_keys_adapter = [] adapter_id = 0 for _layer_type in self.layers_block_type: if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." - + str(adapter_id) - + ".0.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_q_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_k_adapter_list." - + str(adapter_id) - + ".1.weight" - ) - tied_keys_adapter.append( - "shared_transformer.self_attn.linear_v_adapter_list." + attn_adapter_pattern = re.compile( + r"^shared_transformer\.self_attn\." + + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\." + str(adapter_id) - + ".1.weight" + + r"\.(?:0|1)\.weight$" ) + self._tied_weights_keys.append(attn_adapter_pattern) adapter_id += 1 - self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter] layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers))) else: layers.append(next(mamba_layers)) From 5e092909971410e31d82531b718ed1ab0fe71221 Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 24 Jan 2025 02:49:05 +0000 Subject: [PATCH 67/73] add cpu to slow forward tests --- tests/models/zamba2/test_modeling_zamba2.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index bd3df9f6e59..2bd6732514c 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -582,8 +582,9 @@ def setUpClass(cls): ) cls.tokenizer = AutoTokenizer.from_pretrained(model_id, revision="PR") + @parameterized.expand([(torch_device,), ("cpu",)]) @slow - def test_simple_generate(self): + def test_simple_generate(self, torch_device): self.model.to(torch_device) input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[ @@ -610,8 +611,9 @@ def test_simple_generate(self): , dtype=torch.float32) # fmt: skip torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3) + @parameterized.expand([(torch_device,), ("cpu",)]) @slow - def test_simple_batched_generate_with_padding(self): + def test_simple_batched_generate_with_padding(self, torch_device): self.model.to(torch_device) inputs = self.tokenizer( @@ -656,4 +658,9 @@ def test_simple_batched_generate_with_padding(self): , dtype=torch.float32) # fmt: skip torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + logits[1, -1, :40].cpu(), + 
EXPECTED_LOGITS_NO_GRAD_1, + rtol=1e-3, + atol=6e-3 if torch_device == "cpu" else 1e-3, + ) From 8ed23534535afd0e17b252406878894744cb914c Mon Sep 17 00:00:00 2001 From: pglorio Date: Fri, 24 Jan 2025 03:12:44 +0000 Subject: [PATCH 68/73] dropped config.use_shared_mlp_adapter --- .../models/zamba2/configuration_zamba2.py | 4 -- .../models/zamba2/modeling_zamba2.py | 47 +++++++++---------- .../models/zamba2/modular_zamba2.py | 47 +++++++++---------- 3 files changed, 44 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py index 6e66bb53dea..975e9687358 100644 --- a/src/transformers/models/zamba2/configuration_zamba2.py +++ b/src/transformers/models/zamba2/configuration_zamba2.py @@ -83,8 +83,6 @@ class Zamba2Config(PretrainedConfig): The dropout ratio for the attention probabilities. num_mem_blocks (`int`, *optional*, defaults to 1): Number of unshared transformer blocks. - use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`): - If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's. use_shared_attention_adapter (`bool`, *optional*, defaults to `False`): If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers. adapter_rank (`int`, *optional*, defaults to 128): @@ -152,7 +150,6 @@ def __init__( num_key_value_heads=None, attention_dropout=0.0, num_mem_blocks=1, - use_shared_mlp_adapter=False, use_shared_attention_adapter=False, adapter_rank=128, use_mem_rope=False, @@ -203,7 +200,6 @@ def __init__( self.use_conv_bias = use_conv_bias self.chunk_size = chunk_size self.time_step_limit = time_step_limit - self.use_shared_mlp_adapter = use_shared_mlp_adapter self.use_shared_attention_adapter = use_shared_attention_adapter self.adapter_rank = adapter_rank self.time_step_min = time_step_min diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 4961802ef22..0c3fa909707 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -987,26 +987,24 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) self.act_fn = ACT2FN[config.hidden_act] - if self.config.use_shared_mlp_adapter: - self.gate_up_proj_adapter_list = nn.ModuleList([]) - for i in range(self.num_fwd_mem_blocks): - if i % config.num_mem_blocks == block_id: - gate_up_proj_adapter = nn.Sequential( - nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), - nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), - ) - else: - gate_up_proj_adapter = nn.Identity() - self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) + self.gate_up_proj_adapter_list = nn.ModuleList([]) + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + gate_up_proj_adapter = nn.Sequential( + nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), + ) + else: + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) layer_block_map = config.hybrid_layer_ids self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def 
forward(self, hidden_state, layer_idx=None): gate_up_state = self.gate_up_proj(hidden_state) - if self.config.use_shared_mlp_adapter: - layer_idx = self.layer_dic[layer_idx] - gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) + layer_idx = self.layer_dic[layer_idx] + gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] @@ -1576,17 +1574,16 @@ def get_layers(self, blocks, linear_layers, mamba_layers): ) self._tied_weights_keys.append(main_keys_pattern) - if self.config.use_shared_mlp_adapter: - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(adapter_pattern) - adapter_id += 1 + adapter_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + + str(adapter_id) + + r"\.(?:0|1)\.weight$" + ) + self._tied_weights_keys.append(adapter_pattern) + adapter_id += 1 if self.config.use_shared_attention_adapter: adapter_id = 0 for _layer_type in self.layers_block_type: diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index d52f103f9ad..dd62d48ac41 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -721,26 +721,24 @@ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear) self.act_fn = ACT2FN[config.hidden_act] - if self.config.use_shared_mlp_adapter: - self.gate_up_proj_adapter_list = nn.ModuleList([]) - for i in range(self.num_fwd_mem_blocks): - if i % config.num_mem_blocks == block_id: - gate_up_proj_adapter = nn.Sequential( - nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), - nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), - ) - else: - gate_up_proj_adapter = nn.Identity() - self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) + self.gate_up_proj_adapter_list = nn.ModuleList([]) + for i in range(self.num_fwd_mem_blocks): + if i % config.num_mem_blocks == block_id: + gate_up_proj_adapter = nn.Sequential( + nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False), + nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False), + ) + else: + gate_up_proj_adapter = nn.Identity() + self.gate_up_proj_adapter_list.append(gate_up_proj_adapter) layer_block_map = config.hybrid_layer_ids self.layer_dic = {value: index for index, value in enumerate(layer_block_map)} def forward(self, hidden_state, layer_idx=None): gate_up_state = self.gate_up_proj(hidden_state) - if self.config.use_shared_mlp_adapter: - layer_idx = self.layer_dic[layer_idx] - gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) + layer_idx = self.layer_dic[layer_idx] + gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state) gate_up_state = torch.chunk(gate_up_state, 2, dim=-1) 
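With the config flag dropped, the low-rank gate/up adapter is always applied on top of the shared gate_up_proj before the chunk-and-gate step shown in this hunk. A self-contained sketch of that path, using placeholder sizes and a stand-in activation (the real one comes from config.hidden_act):

import torch
from torch import nn

hidden_size, intermediate_size, adapter_rank = 64, 128, 8  # toy sizes, not real config values

gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
# Low-rank (LoRA-style) adapter added on top of the shared projection.
adapter = nn.Sequential(
    nn.Linear(hidden_size, adapter_rank, bias=False),
    nn.Linear(adapter_rank, 2 * intermediate_size, bias=False),
)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.GELU()

x = torch.randn(2, 5, hidden_size)
gate_up = gate_up_proj(x) + adapter(x)      # shared weights plus per-layer low-rank correction
gate, up = torch.chunk(gate_up, 2, dim=-1)  # split into gate and up halves
y = down_proj(act_fn(gate) * up)
print(y.shape)  # torch.Size([2, 5, 64])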
hidden_state = self.act_fn(gate_up_state[0]) * gate_up_state[1] @@ -994,17 +992,16 @@ def get_layers(self, blocks, linear_layers, mamba_layers): ) self._tied_weights_keys.append(main_keys_pattern) - if self.config.use_shared_mlp_adapter: - adapter_id = 0 - for _layer_type in self.layers_block_type: - if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: - adapter_pattern = re.compile( - r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." - + str(adapter_id) - + r"\.(?:0|1)\.weight$" - ) - self._tied_weights_keys.append(adapter_pattern) - adapter_id += 1 + adapter_id = 0 + for _layer_type in self.layers_block_type: + if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id: + adapter_pattern = re.compile( + r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + + str(adapter_id) + + r"\.(?:0|1)\.weight$" + ) + self._tied_weights_keys.append(adapter_pattern) + adapter_id += 1 if self.config.use_shared_attention_adapter: adapter_id = 0 for _layer_type in self.layers_block_type: From a9bbd9c15a9fc84013a23d39009890fe320181ff Mon Sep 17 00:00:00 2001 From: pglorio <85982602+pglorio@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:56:20 -0800 Subject: [PATCH 69/73] Update docs/source/en/model_doc/zamba2.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/model_doc/zamba2.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/en/model_doc/zamba2.md b/docs/source/en/model_doc/zamba2.md index 4aa2a437ead..b331e10eaf8 100644 --- a/docs/source/en/model_doc/zamba2.md +++ b/docs/source/en/model_doc/zamba2.md @@ -34,7 +34,6 @@ Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space m Zamba2 requires you use `transformers` version 4.48.0 or higher: ```bash pip install transformers>=4.48.0 - ## Inference ```python From 37bff341dfaa384e370ecbbbe03bfcbe810a33a4 Mon Sep 17 00:00:00 2001 From: pglorio Date: Mon, 27 Jan 2025 06:55:47 +0000 Subject: [PATCH 70/73] re-convert from modular --- .../models/zamba2/modeling_zamba2.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 0c3fa909707..04ff9864941 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -38,6 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config @@ -1632,6 +1633,7 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model + @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -1647,7 +1649,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, + logits_to_keep: Union[int, torch.Tensor] = 0, **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" @@ -1657,10 +1659,12 
@@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - num_logits_to_keep (`int` or `None`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all - `input_ids`. Only last token logits are needed for generation, and calculating them only for that token - can save memory, which becomes pretty significant for long sequences. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: @@ -1704,7 +1708,8 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: @@ -1770,7 +1775,7 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, - "num_logits_to_keep": self.config.num_logits_to_keep, + "logits_to_keep": self.config.num_logits_to_keep, "cache_position": cache_position, } ) From cd304b5139198ac1c6ae822704347a7c645b6185 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 28 Jan 2025 19:04:12 +0000 Subject: [PATCH 71/73] extended Zamba2RMSNormGated to n_groups>1 --- .../models/zamba2/modeling_zamba2.py | 16 +++++++----- .../models/zamba2/modular_zamba2.py | 26 ++++++++++++++++--- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 04ff9864941..48101201a94 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from einops import rearrange from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -62,20 +63,21 @@ class Zamba2RMSNormGated(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-6): + def __init__(self, hidden_size, group_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.group_size = group_size def forward(self, hidden_states, gate=None): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - if gate is not None: hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - + hidden_states_group = rearrange(hidden_states, "... (g d) -> ... 
g d", d=self.group_size) + variance = hidden_states_group.pow(2).mean(-1, keepdim=True) + hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = rearrange(hidden_states_group, "... g d -> ... (g d)") return self.weight * hidden_states.to(input_dtype) @@ -601,7 +603,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): A = torch.arange(1, self.num_heads + 1) self.A_log = nn.Parameter(torch.log(A)) self.A_log._no_weight_decay = True - self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5) + self.norm = Zamba2RMSNormGated( + self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5 + ) self.D = nn.Parameter(torch.ones(self.num_heads)) self.D._no_weight_decay = True diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index dd62d48ac41..26203c13753 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -20,6 +20,7 @@ import torch import torch.utils.checkpoint +from einops import rearrange from torch import nn from ...activations import ACT2FN @@ -35,7 +36,7 @@ is_mamba_ssm_available, ) from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb -from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum +from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum from ..zamba.modeling_zamba import ( ZambaAttention, ZambaAttentionDecoderLayer, @@ -70,8 +71,23 @@ logger = logging.get_logger(__name__) -class Zamba2RMSNormGated(MambaRMSNormGated): - pass +class Zamba2RMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + def forward(self, hidden_states, gate=None): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + if gate is not None: + hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) + hidden_states_group = rearrange(hidden_states, "... (g d) -> ... g d", d=self.group_size) + variance = hidden_states_group.pow(2).mean(-1, keepdim=True) + hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = rearrange(hidden_states_group, "... g d -> ... 
(g d)") + return self.weight * hidden_states.to(input_dtype) class Zamba2RMSNorm(ZambaRMSNorm): @@ -334,7 +350,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None): A = torch.arange(1, self.num_heads + 1) self.A_log = nn.Parameter(torch.log(A)) self.A_log._no_weight_decay = True - self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5) + self.norm = Zamba2RMSNormGated( + self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5 + ) self.D = nn.Parameter(torch.ones(self.num_heads)) self.D._no_weight_decay = True From 8f2eb7b9146531a6a9d8062d58fdbc721b10bcb4 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 28 Jan 2025 19:36:47 +0000 Subject: [PATCH 72/73] removed einops import --- src/transformers/models/zamba2/modeling_zamba2.py | 7 ++++--- src/transformers/models/zamba2/modular_zamba2.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 48101201a94..c2313eeb91a 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -25,7 +25,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from einops import rearrange from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -74,10 +73,12 @@ def forward(self, hidden_states, gate=None): hidden_states = hidden_states.to(torch.float32) if gate is not None: hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - hidden_states_group = rearrange(hidden_states, "... (g d) -> ... g d", d=self.group_size) + *prefix_dims, last_dim = hidden_states.shape + group_count = last_dim // self.group_size + hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size) variance = hidden_states_group.pow(2).mean(-1, keepdim=True) hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = rearrange(hidden_states_group, "... g d -> ... (g d)") + hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size) return self.weight * hidden_states.to(input_dtype) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 26203c13753..8fdc9af2f51 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -20,7 +20,6 @@ import torch import torch.utils.checkpoint -from einops import rearrange from torch import nn from ...activations import ACT2FN @@ -83,10 +82,12 @@ def forward(self, hidden_states, gate=None): hidden_states = hidden_states.to(torch.float32) if gate is not None: hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32)) - hidden_states_group = rearrange(hidden_states, "... (g d) -> ... g d", d=self.group_size) + *prefix_dims, last_dim = hidden_states.shape + group_count = last_dim // self.group_size + hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size) variance = hidden_states_group.pow(2).mean(-1, keepdim=True) hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon) - hidden_states = rearrange(hidden_states_group, "... g d -> ... 
(g d)") + hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size) return self.weight * hidden_states.to(input_dtype) From be7d81ac7b64519c77322f515424f2a14527a045 Mon Sep 17 00:00:00 2001 From: pglorio Date: Tue, 28 Jan 2025 21:42:30 +0000 Subject: [PATCH 73/73] set _supports_sdpa = True --- src/transformers/models/zamba2/modeling_zamba2.py | 2 +- src/transformers/models/zamba2/modular_zamba2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index c2313eeb91a..15876282cb9 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1232,7 +1232,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_flex_attn = True - _supports_sdpa = False + _supports_sdpa = True _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 8fdc9af2f51..6e921764513 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -915,7 +915,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_flex_attn = True - _supports_sdpa = False + _supports_sdpa = True _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache _is_stateful = True
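
The tied-weights bookkeeping in the `get_layers` hunk above builds one regex per hybrid layer. A minimal sketch of how such a pattern matches adapter parameter names (the `adapter_id` value and the candidate strings here are illustrative, not taken from a real checkpoint):

```python
import re

adapter_id = 3  # illustrative; in the model this counts hybrid layers assigned to a shared block
adapter_pattern = re.compile(
    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
    + str(adapter_id)
    + r"\.(?:0|1)\.weight$"
)

# The (?:0|1) group matches only the two linear layers of the adapter.
print(bool(adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.3.0.weight")))  # True
print(bool(adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.3.1.weight")))  # True
print(bool(adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.3.2.weight")))  # False
```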
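The `num_logits_to_keep` to `logits_to_keep` change above keeps the same slicing trick but also accepts explicit indices. A small self-contained sketch of that slicing, with illustrative tensor sizes rather than Zamba2's real dimensions:

```python
import torch

hidden_states = torch.randn(2, 10, 16)         # (batch, seq_len, hidden); sizes are illustrative
lm_head = torch.nn.Linear(16, 32, bias=False)  # toy vocabulary of 32 tokens


def project(logits_to_keep):
    # Same expression as in the forward pass above.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return lm_head(hidden_states[:, slice_indices, :])


print(project(0).shape)                     # torch.Size([2, 10, 32]) -> 0 keeps every position
print(project(1).shape)                     # torch.Size([2, 1, 32])  -> only the last token, as in generation
print(project(torch.tensor([3, 9])).shape)  # torch.Size([2, 2, 32])  -> explicit indices, e.g. packed sequences
```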
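Patches 71 and 72 above first generalize the gated RMSNorm to `n_groups > 1` via `einops.rearrange` and then drop the `einops` dependency in favor of plain `view` calls. A runnable sketch of the resulting grouped norm (shapes illustrative), checking that with `group_size == hidden_size` it reduces to the ordinary gated RMSNorm:

```python
import torch
from torch import nn


class GatedGroupedRMSNorm(nn.Module):
    """View-based grouped gated RMSNorm, mirroring the Zamba2RMSNormGated hunks above."""

    def __init__(self, hidden_size, group_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        # Normalize each contiguous block of `group_size` channels independently.
        *prefix_dims, last_dim = hidden_states.shape
        group_count = last_dim // self.group_size
        grouped = hidden_states.view(*prefix_dims, group_count, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = grouped.view(*prefix_dims, last_dim)
        return self.weight * hidden_states.to(input_dtype)


x, gate = torch.randn(2, 5, 8), torch.randn(2, 5, 8)
ungrouped = GatedGroupedRMSNorm(hidden_size=8, group_size=8)  # a single group is a plain gated RMSNorm
gated = x * nn.functional.silu(gate)
reference = gated * torch.rsqrt(gated.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(ungrouped(x, gate), reference, atol=1e-6))  # True
print(GatedGroupedRMSNorm(hidden_size=8, group_size=4)(x, gate).shape)  # torch.Size([2, 5, 8])
```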