From 69ba6c9cd096ec81395d109a38dad843bcfb2217 Mon Sep 17 00:00:00 2001 From: xiaohan Date: Sat, 25 Jan 2025 21:42:01 +0800 Subject: [PATCH] add three files. --- .../models/mimi/configuration_mimi.py | 237 +++ .../transformers/models/mimi/modeling_mimi.py | 1788 +++++++++++++++++ .../models/mimi/test_modeling_mimi.py | 860 ++++++++ 3 files changed, 2885 insertions(+) create mode 100644 mindnlp/transformers/models/mimi/configuration_mimi.py create mode 100644 mindnlp/transformers/models/mimi/modeling_mimi.py create mode 100644 tests/transformers/models/mimi/test_modeling_mimi.py diff --git a/mindnlp/transformers/models/mimi/configuration_mimi.py b/mindnlp/transformers/models/mimi/configuration_mimi.py new file mode 100644 index 000000000..bf11801bf --- /dev/null +++ b/mindnlp/transformers/models/mimi/configuration_mimi.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mimi model configuration""" + +import math + +import numpy as np + +from ...configuration_utils import PretrainedConfig +from ....utils import logging + + +logger = logging.get_logger(__name__) + + +class MimiConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a + Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the + [kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + sampling_rate (`int`, *optional*, defaults to 24000): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + frame_rate (`float`, *optional*, defaults to 12.5): + Framerate of the model. + audio_channels (`int`, *optional*, defaults to 1): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + hidden_size (`int`, *optional*, defaults to 512): + Intermediate representation dimension. + num_filters (`int`, *optional*, defaults to 64): + Number of convolution kernels of first `MimiConv1d` down sampling layer. + num_residual_layers (`int`, *optional*, defaults to 1): + Number of residual layers. + upsampling_ratios (`Sequence[int]`, *optional*): + Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it + will use the ratios in the reverse order to the ones specified here that must match the decoder order. + If not specified, will defaults to `[8, 6, 5, 4]` + kernel_size (`int`, *optional*, defaults to 7): + Kernel size for the initial convolution. + last_kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the last convolution layer. + residual_kernel_size (`int`, *optional*, defaults to 3): + Kernel size for the residual layers. + dilation_growth_rate (`int`, *optional*, defaults to 2): + How much to increase the dilation with each layer. + use_causal_conv (`bool`, *optional*, defaults to `True`): + Whether to use fully causal convolution. + pad_mode (`str`, *optional*, defaults to `"constant"`): + Padding mode for the convolutions. + compress (`int`, *optional*, defaults to 2): + Reduced dimensionality in residual branches. + trim_right_ratio (`float`, *optional*, defaults to 1.0): + Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If + equal to 1.0, it means that all the trimming is done at the right. + codebook_size (`int`, *optional*, defaults to 2048): + Number of discret codes in each codebooks. + codebook_dim (`int`, *optional*, defaults to 256): + Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`. + num_quantizers (`int`, *optional*, defaults to 32): + Number of quantizer channels, or codebooks, in the quantizer. + use_conv_shortcut (`bool`, *optional*, defaults to `False`): + Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False, + an identity function will be used, giving a generic residual connection. + vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256): + Intermediate representation dimension in the residual vector quantization space. + num_semantic_quantizers (`int`, *optional*, defaults to 1): + Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`. + upsample_groups (`int`, *optional*, defaults to 512): + If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another. + num_hidden_layers (`int`, *optional*, defaults to 8): + Number of hidden layers in the Transformer models. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MLP representations. + num_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 8000): + The maximum sequence length that this model might ever be used with. Mimi's sliding window attention + allows sequence of up to 8000 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the LayerNorm normalization layers. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*, defaults to 250): + Sliding window attention window size. If not specified, will default to `250`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + layer_scale_initial_scale (`float`, *optional*, defaults to 0.01): + Initiale scale of the residual rescaling operation done in the Transformer models. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + Example: + + ```python + >>> from transformers import MimiModel, MimiConfig + + >>> # Initializing a "kyutai/mimi" style configuration + >>> configuration = MimiConfig() + + >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration + >>> model = MimiModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mimi" + + def __init__( + self, + sampling_rate=24_000, + frame_rate=12.5, + audio_channels=1, + hidden_size=512, + num_filters=64, + num_residual_layers=1, + upsampling_ratios=None, + kernel_size=7, + last_kernel_size=3, + residual_kernel_size=3, + dilation_growth_rate=2, + use_causal_conv=True, + pad_mode="constant", + compress=2, + trim_right_ratio=1.0, + codebook_size=2048, + codebook_dim=256, + num_quantizers=32, + use_conv_shortcut=False, + vector_quantization_hidden_dimension=256, + num_semantic_quantizers=1, + upsample_groups=512, + num_hidden_layers=8, + intermediate_size=2048, + num_attention_heads=8, + num_key_value_heads=8, + head_dim=None, + hidden_act="gelu", + max_position_embeddings=8000, + initializer_range=0.02, + norm_eps=1e-5, + use_cache=False, + rope_theta=10000.0, + sliding_window=250, + attention_dropout=0.0, + layer_scale_initial_scale=0.01, + attention_bias=False, + **kwargs, + ): + self.sampling_rate = sampling_rate + self.frame_rate = frame_rate + self.audio_channels = audio_channels + self.hidden_size = hidden_size + self.num_filters = num_filters + self.num_residual_layers = num_residual_layers + self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4] + self.kernel_size = kernel_size + self.last_kernel_size = last_kernel_size + self.residual_kernel_size = residual_kernel_size + self.dilation_growth_rate = dilation_growth_rate + self.use_causal_conv = use_causal_conv + self.pad_mode = pad_mode + self.compress = compress + self.trim_right_ratio = trim_right_ratio + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size + self.num_quantizers = num_quantizers + self.use_conv_shortcut = use_conv_shortcut + self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension + self.upsample_groups = upsample_groups + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.attention_dropout = attention_dropout + self.head_dim = head_dim or hidden_size // num_attention_heads + self.layer_scale_initial_scale = layer_scale_initial_scale + self.attention_bias = attention_bias + + if num_semantic_quantizers >= self.num_quantizers: + raise ValueError( + f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}." + ) + self.num_semantic_quantizers = num_semantic_quantizers + super().__init__(**kwargs) + + @property + def encodec_frame_rate(self) -> int: + hop_length = np.prod(self.upsampling_ratios) + return math.ceil(self.sampling_rate / hop_length) + + @property + def num_codebooks(self) -> int: + # alias to num_quantizers + return self.num_quantizers + + +__all__ = ["MimiConfig"] diff --git a/mindnlp/transformers/models/mimi/modeling_mimi.py b/mindnlp/transformers/models/mimi/modeling_mimi.py new file mode 100644 index 000000000..87dbc9394 --- /dev/null +++ b/mindnlp/transformers/models/mimi/modeling_mimi.py @@ -0,0 +1,1788 @@ +# coding=utf-8 +# Copyright 2024 Kyutai, and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mimi model.""" + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import mindspore +from mindnlp.core import nn, ops +from mindnlp.core.nn import functional as F + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ....utils import ( + ModelOutput, + logging, +) +from .configuration_mimi import MimiConfig + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +# General docstring +_CONFIG_FOR_DOC = "MimiConfig" + + +@dataclass +class MimiOutput(ModelOutput): + """ + Args: + audio_codes (`mindspore.Tensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + audio_values (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + Decoded audio values, obtained using the decoder part of Mimi. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_codes: mindspore.Tensor = None + audio_values: mindspore.Tensor = None + encoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None + decoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None + + +@dataclass +class MimiEncoderOutput(ModelOutput): + """ + Args: + audio_codes (`mindspore.Tensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_codes: mindspore.Tensor = None + encoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None + + +@dataclass +class MimiDecoderOutput(ModelOutput): + """ + Args: + audio_values (`mindspore.Tensor` of shape `(batch_size, segment_length)`, *optional*): + Decoded audio values, obtained using the decoder part of Mimi. + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + """ + + audio_values: mindspore.Tensor = None + decoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None + + +class MimiConv1d(nn.Module): + """Conv1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + pad_mode=None, + bias: bool = True, + ): + super().__init__() + self.causal = config.use_causal_conv + self.pad_mode = config.pad_mode if pad_mode is None else pad_mode + + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + logger.warning( + "MimiConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias + ) + + kernel_size = self.conv.kernel_size[0] + stride = mindspore.tensor(self.conv.stride[0], dtype=mindspore.int3264) + dilation = self.conv.dilation[0] + + # Effective kernel size with dilations. + kernel_size = mindspore.tensor((kernel_size - 1) * dilation + 1, dtype=mindspore.int3264) + + self.register_buffer("stride", stride, persistent=False) + self.register_buffer("kernel_size", kernel_size, persistent=False) + self.register_buffer("padding_total", mindspore.tensor(kernel_size - stride, dtype=mindspore.int3264), persistent=False) + + # Asymmetric padding required for odd strides + self.padding_right = self.padding_total // 2 + self.padding_left = self.padding_total - self.padding_right + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._get_extra_padding_for_conv1d + def _get_extra_padding_for_conv1d( + self, + hidden_states: mindspore.Tensor, + ) -> mindspore.Tensor: + """See `pad_for_conv1d`.""" + length = hidden_states.shape[-1] + n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1 + n_frames = ops.ceil(n_frames).to(mindspore.int3264) - 1 + ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total + + return ideal_length - length + + @staticmethod + # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d + def _pad1d(hidden_states: mindspore.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0): + """Tiny wrapper around ops.nn.functional.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happens. + """ + length = hidden_states.shape[-1] + padding_left, padding_right = paddings + if not mode == "reflect": + return nn.functional.pad(hidden_states, paddings, mode, value) + + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + hidden_states = nn.functional.pad(hidden_states, (0, extra_pad)) + padded = nn.functional.pad(hidden_states, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + + def forward(self, hidden_states): + extra_padding = self._get_extra_padding_for_conv1d(hidden_states) + + if self.causal: + # Left padding for causal + hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode) + else: + hidden_states = self._pad1d( + hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode + ) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class MimiConvTranspose1d(nn.Module): + """ConvTranspose1d with asymmetric or causal padding and normalization.""" + + def __init__( + self, + config, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias=True, + ): + super().__init__() + self.causal = config.use_causal_conv + self.trim_right_ratio = config.trim_right_ratio + self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias) + + if not (self.causal or self.trim_right_ratio == 1.0): + raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions") + + kernel_size = self.conv.kernel_size[0] + stride = self.conv.stride[0] + padding_total = kernel_size - stride + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if self.causal: + # Trim the padding on the right according to the specified ratio + # if trim_right_ratio = 1.0, trim everything from right + self.padding_right = math.ceil(padding_total * self.trim_right_ratio) + else: + # Asymmetric padding required for odd strides + self.padding_right = padding_total // 2 + + self.padding_left = padding_total - self.padding_right + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + weight_norm(self.conv) + + def remove_weight_norm(self): + nn.utils.remove_weight_norm(self.conv) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + # unpad + end = hidden_states.shape[-1] - self.padding_right + hidden_states = hidden_states[..., self.padding_left : end] + return hidden_states + + +# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi +class MimiResnetBlock(nn.Module): + """ + Residual block from SEANet model as used by Mimi. + """ + + def __init__(self, config: MimiConfig, dim: int, dilations: List[int]): + super().__init__() + kernel_sizes = (config.residual_kernel_size, 1) + if len(kernel_sizes) != len(dilations): + raise ValueError("Number of kernel sizes should match number of dilations") + + hidden = dim // config.compress + block = [] + for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)): + in_chs = dim if i == 0 else hidden + out_chs = dim if i == len(kernel_sizes) - 1 else hidden + block += [nn.ELU()] + block += [MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)] + self.block = nn.ModuleList(block) + + if config.use_conv_shortcut: + self.shortcut = MimiConv1d(config, dim, dim, kernel_size=1) + else: + self.shortcut = nn.Identity() + + def forward(self, hidden_states): + residual = hidden_states + for layer in self.block: + hidden_states = layer(hidden_states) + + return self.shortcut(residual) + hidden_states + + +class MimiEncoder(nn.Module): + """SEANet encoder as used by Mimi.""" + + def __init__(self, config: MimiConfig): + super().__init__() + model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)] + scaling = 1 + + # Downsample to raw audio scale + for ratio in reversed(config.upsampling_ratios): + current_scale = scaling * config.num_filters + # Add residual layers + for j in range(config.num_residual_layers): + model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])] + # Add downsampling layers + model += [nn.ELU()] + model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)] + scaling *= 2 + + model += [nn.ELU()] + model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)] + + self.layers = nn.ModuleList(model) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class MimiLayerScale(nn.Module): + """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). + This rescales diagonally the residual outputs close to 0, with a learnt scale. + """ + + def __init__(self, config): + super().__init__() + channels = config.hidden_size + initial_scale = config.layer_scale_initial_scale + self.scale = nn.Parameter(ops.full((channels,), initial_scale, requires_grad=True)) + + def forward(self, x: mindspore.Tensor): + return self.scale * x + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi +class MimiRotaryEmbedding(nn.Module): + def __init__(self, config: MimiConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = ops.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + self.original_inv_freq = self.original_inv_freq.to(device) + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @ops.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with ops.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = ops.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return ops.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`mindspore.Tensor`): The query tensor. + k (`mindspore.Tensor`): The key tensor. + cos (`mindspore.Tensor`): The cosine part of the rotary embedding. + sin (`mindspore.Tensor`): The sine part of the rotary embedding. + position_ids (`mindspore.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MimiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward + def forward(self, hidden_states: mindspore.Tensor) -> mindspore.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor: + """ + This is the equivalent of ops.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# no longer copied after attention refactors +class MimiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = MimiRotaryEmbedding(config) + self.sliding_window = config.sliding_window # Ignore copy + + def forward( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[mindspore.Tensor] = None, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]: + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = ops.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=mindspore.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = ops.matmul(attn_weights, value_states) + + if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.shape}" + ) + + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# TODO cyril: modular +class MimiFlashAttention2(MimiAttention): + """ + Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[mindspore.Tensor] = None, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MimiRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == mindspore.float32: + if ops.is_autocast_enabled(): + target_dtype = ops.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# TODO cyril: modular +class MimiSdpaAttention(MimiAttention): + """ + Mimi attention module using ops.nn.functional.scaled_dot_product_attention. This module inherits from + `MimiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MimiAttention.forward + def forward( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[mindspore.Tensor] = None, + **kwargs, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor], Optional[Tuple[mindspore.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MimiModel is using MimiSdpaAttention, but `ops.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.shape + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states + key_states = key_states + value_states = value_states + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both ops.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = ops.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIMI_ATTENTION_CLASSES = { + "eager": MimiAttention, + "flash_attention_2": MimiFlashAttention2, + "sdpa": MimiSdpaAttention, +} + + +class MimiTransformerLayer(nn.Module): + def __init__(self, config: MimiConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MimiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps) + self.self_attn_layer_scale = MimiLayerScale(config) + self.mlp_layer_scale = MimiLayerScale(config) + + def forward( + self, + hidden_states: mindspore.Tensor, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[mindspore.Tensor] = None, + **kwargs, + ) -> Tuple[mindspore.Tensor, Optional[Tuple[mindspore.Tensor, mindspore.Tensor]]]: + """ + Args: + hidden_states (`mindspore.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`mindspore.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(mindspore.Tensor)`, *optional*): cached past key and value projection states + cache_position (`mindspore.Tensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + self.self_attn_layer_scale(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.mlp_layer_scale(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MimiTransformerModel(nn.Module): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MimiTransformerLayer`] + + Args: + config: MimiConfig + """ + + def __init__(self, config: MimiConfig): + super().__init__() + + self.layers = nn.ModuleList( + [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: mindspore.Tensor = None, + attention_mask: Optional[mindspore.Tensor] = None, + position_ids: Optional[mindspore.Tensor] = None, + past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[mindspore.Tensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Args: + hidden_states (`mindspore.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Embedded representation that will be contextualized by the model + attention_mask (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(mindspore.Tensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(mindspore.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if use_cache and not isinstance(past_key_values, Cache): + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = ops.arange( + past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = None + if attention_mask is not None: + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi + def _update_causal_mask( + self, + attention_mask: mindspore.Tensor, + input_tensor: mindspore.Tensor, + cache_position: mindspore.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.shape[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = ops.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, mindspore.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: mindspore.Tensor, + sequence_length: int, + target_length: int, + dtype: mindspore.dtype, + device: mindspore.device, + cache_position: mindspore.Tensor, + batch_size: int, + config: MimiConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`mindspore.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`mindspore.dtype`): + The dtype to use for the 4D attention mask. + device (`mindspore.device`): + The device to plcae the 4D attention mask on. + cache_position (`mindspore.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`mindspore.Tensor`): + Batch size. + config (`MimiConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = ops.finfo(dtype).min + causal_mask = ops.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = ops.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = ops.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class MimiDecoder(nn.Module): + """SEANet decoder as used by Mimi.""" + + def __init__(self, config: MimiConfig): + super().__init__() + scaling = int(2 ** len(config.upsampling_ratios)) + model = [MimiConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)] + + # Upsample to raw audio scale + for ratio in config.upsampling_ratios: + current_scale = scaling * config.num_filters + # Add upsampling layers + model += [nn.ELU()] + model += [ + MimiConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio) + ] + # Add residual layers + for j in range(config.num_residual_layers): + model += [MimiResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))] + scaling //= 2 + + # Add final layers + model += [nn.ELU()] + model += [MimiConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)] + self.layers = nn.ModuleList(model) + + # Copied from transformers.models.encodec.modeling_encodec.EncodecDecoder.forward + def forward(self, hidden_states): + for layer in self.layers: + hidden_states = layer(hidden_states) + return hidden_states + + +class MimiEuclideanCodebook(nn.Module): + """Codebook with Euclidean distance.""" + + def __init__(self, config: MimiConfig, epsilon: float = 1e-5): + super().__init__() + embed = ops.zeros(config.codebook_size, config.codebook_dim) + + self.codebook_size = config.codebook_size + + self.register_buffer("initialized", mindspore.tensor([True], dtype=mindspore.float32)) + self.register_buffer("cluster_usage", ops.ones(config.codebook_size)) + self.register_buffer("embed_sum", embed) + self._embed = None + self.epsilon = epsilon + + @property + def embed(self) -> mindspore.Tensor: + if self._embed is None: + self._embed = self.embed_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None] + return self._embed + + def quantize(self, hidden_states): + # Projects each vector in `hidden_states` over the nearest centroid and return its index. + # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension. + dists = ops.cdist(hidden_states[None], self.embed[None], p=2)[0] + embed_ind = dists.argmin(dim=-1) + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode + def encode(self, hidden_states): + shape = hidden_states.shape + # pre-process + hidden_states = hidden_states.reshape((-1, shape[-1])) + # quantize + embed_ind = self.quantize(hidden_states) + # post-process + embed_ind = embed_ind.view(*shape[:-1]) + return embed_ind + + # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode + def decode(self, embed_ind): + quantize = nn.functional.embedding(embed_ind, self.embed) + return quantize + + +# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec->Mimi +class MimiVectorQuantization(nn.Module): + """ + Vector quantization implementation. Currently supports only euclidean distance. + """ + + def __init__(self, config: MimiConfig): + super().__init__() + self.codebook = MimiEuclideanCodebook(config) + + def encode(self, hidden_states): + hidden_states = hidden_states.permute(0, 2, 1) + embed_in = self.codebook.encode(hidden_states) + return embed_in + + def decode(self, embed_ind): + quantize = self.codebook.decode(embed_ind) + quantize = quantize.permute(0, 2, 1) + return quantize + + +class MimiResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer.""" + + def __init__(self, config: MimiConfig, num_quantizers: int = None): + super().__init__() + self.codebook_size = config.codebook_size + self.frame_rate = config.frame_rate + self.num_quantizers = num_quantizers if num_quantizers is not None else config.num_quantizers + self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)]) + + self.input_proj = None + self.output_proj = None + if config.vector_quantization_hidden_dimension != config.hidden_size: + self.input_proj = ops.nn.Conv1d( + config.hidden_size, config.vector_quantization_hidden_dimension, 1, bias=False + ) + self.output_proj = ops.nn.Conv1d( + config.vector_quantization_hidden_dimension, config.hidden_size, 1, bias=False + ) + + def encode(self, embeddings: mindspore.Tensor, num_quantizers: Optional[int] = None) -> mindspore.Tensor: + """ + Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets + the appropriate number of quantizers to use and returns indices for each quantizer. + """ + if self.input_proj is not None: + embeddings = self.input_proj(embeddings) + + num_quantizers = num_quantizers if num_quantizers is not None else self.num_quantizers + + residual = embeddings + all_indices = [] + for layer in self.layers[:num_quantizers]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = ops.stack(all_indices) + return out_indices + + def decode(self, codes: mindspore.Tensor) -> mindspore.Tensor: + """Decode the given codes of shape [B, K, T] to the quantized representation.""" + quantized_out = mindspore.tensor(0.0, device=codes.device) + codes = codes.transpose(0, 1) + for i, indices in enumerate(codes): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + + if self.output_proj is not None: + quantized_out = self.output_proj(quantized_out) + return quantized_out + + +class MimiSplitResidualVectorQuantizer(nn.Module): + """Split Residual Vector Quantizer.""" + + def __init__(self, config: MimiConfig): + super().__init__() + self.codebook_size = config.codebook_size + self.frame_rate = config.frame_rate + self.max_num_quantizers = config.num_quantizers + + self.num_semantic_quantizers = config.num_semantic_quantizers + self.num_acoustic_quantizers = config.num_quantizers - config.num_semantic_quantizers + + self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers) + self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_acoustic_quantizers) + + def encode(self, embeddings: mindspore.Tensor, num_quantizers: Optional[float] = None) -> mindspore.Tensor: + """ + Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets + the appropriate number of quantizers to use and returns indices for each quantizer. + """ + + num_quantizers = self.max_num_quantizers if num_quantizers is None else num_quantizers + + if num_quantizers > self.max_num_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}." + ) + + if num_quantizers < self.num_semantic_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be higher than the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}." + ) + + # codes is [K, B, T], with T frames, K nb of codebooks. + codes = self.semantic_residual_vector_quantizer.encode(embeddings) + + if num_quantizers > self.num_semantic_quantizers: + acoustic_codes = self.acoustic_residual_vector_quantizer.encode( + embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers + ) + codes = ops.cat([codes, acoustic_codes], dim=0) + + return codes + + def decode(self, codes: mindspore.Tensor) -> mindspore.Tensor: + """Decode the given codes to the quantized representation.""" + + # The first num_semantic_quantizers codebooks are decoded using the semantic RVQ + quantized_out = self.semantic_residual_vector_quantizer.decode(codes[:, : self.num_semantic_quantizers]) + + # The rest of the codebooks are decoded using the acoustic RVQ + if codes.shape[1] > self.num_semantic_quantizers: + quantized_out += self.acoustic_residual_vector_quantizer.decode(codes[:, self.num_semantic_quantizers :]) + return quantized_out + + +class MimiPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MimiConfig + base_model_prefix = "mimi" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _no_split_modules = ["MimiDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + # Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LSTM): + for name, param in module.named_parameters(): + if "weight" in name: + nn.init.xavier_uniform_(param) + elif "bias" in name: + nn.init.constant_(param, 0.0) + + +MIMI_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [ops.nn.Module](https://pytorch.org/docs/stable/nn.html#ops.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MimiConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +MIMI_INPUTS_DOCSTRING = r""" + Args: + input_values (`mindspore.Tensor` of shape `(batch_size, channels, sequence_length)`, *optional*): + Raw audio input converted to Float. + padding_mask (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + num_quantizers (`int`, *optional*): + Number of quantizers (i.e codebooks) to use. By default, all quantizers are used. + audio_codes (`mindspore.Tensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + + +class MimiModel(MimiPreTrainedModel): + def __init__(self, config: MimiConfig): + super().__init__(config) + self.config = config + + self.encoder = MimiEncoder(config) + self.encoder_transformer = MimiTransformerModel(config) + + self.downsample = None + self.upsample = None + if config.frame_rate != config.encodec_frame_rate: + self.downsample = MimiConv1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + pad_mode="replicate", + ) + + self.upsample = MimiConvTranspose1d( + config, + config.hidden_size, + config.hidden_size, + kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate), + stride=2, + bias=False, + groups=config.upsample_groups, + ) + + self.decoder_transformer = MimiTransformerModel(config) + self.decoder = MimiDecoder(config) + + self.quantizer = MimiSplitResidualVectorQuantizer(config) + + self.bits_per_codebook = int(math.log2(self.config.codebook_size)) + if 2**self.bits_per_codebook != self.config.codebook_size: + raise ValueError("The codebook_size must be a power of 2.") + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _encode_frame( + self, + input_values: mindspore.Tensor, + num_quantizers: int, + padding_mask: int, + past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Tuple[mindspore.Tensor, Optional[mindspore.Tensor]]: + """ + Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale. + """ + embeddings = self.encoder(input_values) + encoder_outputs = self.encoder_transformer( + embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict + ) + if return_dict: + past_key_values = encoder_outputs.get("past_key_values") + elif len(encoder_outputs) > 1: + past_key_values = encoder_outputs[1] + embeddings = encoder_outputs[0].transpose(1, 2) + embeddings = self.downsample(embeddings) + + codes = self.quantizer.encode(embeddings, num_quantizers) + codes = codes.transpose(0, 1) + return codes, past_key_values + + def encode( + self, + input_values: mindspore.Tensor, + padding_mask: mindspore.Tensor = None, + num_quantizers: Optional[float] = None, + encoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[mindspore.Tensor, Optional[mindspore.Tensor]], MimiEncoderOutput]: + """ + Encodes the input audio waveform into discrete codes. + + Args: + input_values (`mindspore.Tensor` of shape `(batch_size, channels, sequence_length)`): + Float values of the input audio waveform. + padding_mask (`mindspore.Tensor` of shape `(batch_size, channels, sequence_length)`): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + num_quantizers (`int`, *optional*): + Number of quantizers (i.e codebooks) to use. By default, all quantizers are used. + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform. + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers + + if num_quantizers > self.config.num_quantizers: + raise ValueError( + f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.config.num_quantizers}, but is currently {num_quantizers}." + ) + + _, channels, input_length = input_values.shape + + if channels < 1 or channels > 2: + raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}") + + if padding_mask is None: + padding_mask = ops.ones_like(input_values).bool() + + encoded_frames, encoder_past_key_values = self._encode_frame( + input_values, + num_quantizers, + padding_mask.bool(), + past_key_values=encoder_past_key_values, + return_dict=return_dict, + ) + + if not return_dict: + return ( + encoded_frames, + encoder_past_key_values, + ) + + return MimiEncoderOutput(encoded_frames, encoder_past_key_values) + + def _decode_frame( + self, + codes: mindspore.Tensor, + past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + return_dict: Optional[bool] = None, + ) -> mindspore.Tensor: + embeddings = self.quantizer.decode(codes) + + embeddings = self.upsample(embeddings) + decoder_outputs = self.decoder_transformer( + embeddings.transpose(1, 2), past_key_values=past_key_values, return_dict=return_dict + ) + if return_dict: + past_key_values = decoder_outputs.get("past_key_values") + elif len(decoder_outputs) > 1: + past_key_values = decoder_outputs[1] + embeddings = decoder_outputs[0].transpose(1, 2) + outputs = self.decoder(embeddings) + return outputs, past_key_values + + def decode( + self, + audio_codes: mindspore.Tensor, + padding_mask: Optional[mindspore.Tensor] = None, + decoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[mindspore.Tensor, mindspore.Tensor], MimiDecoderOutput]: + """ + Decodes the given frames into an output audio waveform. + + Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be + trimmed. + + Args: + audio_codes (`mindspore.Tensor` of shape `(batch_size, num_quantizers, codes_length)`, *optional*): + Discret code embeddings computed using `model.encode`. + padding_mask (`mindspore.Tensor` of shape `(batch_size, channels, sequence_length)`): + Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0 + for *masked*. + decoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer. + This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + The model will output the same cache format that is fed as input. + + If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't + have their past key value states given to this model). + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + """ + return_dict = return_dict if return_dict is not None else self.config.return_dict + + audio_values, decoder_past_key_values = self._decode_frame( + audio_codes, past_key_values=decoder_past_key_values, return_dict=return_dict + ) + + # truncate based on padding mask + if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]: + audio_values = audio_values[..., : padding_mask.shape[-1]] + + if not return_dict: + return ( + audio_values, + decoder_past_key_values, + ) + return MimiDecoderOutput(audio_values, decoder_past_key_values) + + + def forward( + self, + input_values: mindspore.Tensor, + padding_mask: Optional[mindspore.Tensor] = None, + num_quantizers: Optional[int] = None, + audio_codes: Optional[mindspore.Tensor] = None, + encoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + decoder_past_key_values: Optional[Union[Cache, List[mindspore.Tensor]]] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[mindspore.Tensor, mindspore.Tensor], MimiOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from datasets import load_dataset + >>> from transformers import AutoFeatureExtractor, MimiModel + + >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example") + >>> audio_sample = dataset["train"]["audio"][0]["array"] + + >>> model_id = "kyutai/mimi" + >>> model = MimiModel.from_pretrained(model_id) + >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id) + + >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> audio_codes = outputs.audio_codes + >>> audio_values = outputs.audio_values + ```""" + return_dict = return_dict if return_dict is not None else self.config.return_dict + + if padding_mask is None: + padding_mask = ops.ones_like(input_values).bool() + + if audio_codes is None: + encoder_outputs = self.encode( + input_values, padding_mask, num_quantizers, encoder_past_key_values, return_dict=return_dict + ) + audio_codes = encoder_outputs[0] + if return_dict: + encoder_past_key_values = encoder_outputs.get("past_key_values") + elif len(encoder_outputs) > 1: + encoder_past_key_values = encoder_outputs[1] + + decoder_outputs = self.decode(audio_codes, padding_mask, decoder_past_key_values, return_dict=return_dict) + audio_values = decoder_outputs[0] + if return_dict: + decoder_past_key_values = decoder_outputs.get("past_key_values") + elif len(decoder_outputs) > 1: + decoder_past_key_values = decoder_outputs[1] + + if not return_dict: + return (audio_codes, audio_values, encoder_past_key_values, decoder_past_key_values) + + return MimiOutput( + audio_codes=audio_codes, + audio_values=audio_values, + encoder_past_key_values=encoder_past_key_values, + decoder_past_key_values=decoder_past_key_values, + ) + + +__all__ = ["MimiModel", "MimiPreTrainedModel"] diff --git a/tests/transformers/models/mimi/test_modeling_mimi.py b/tests/transformers/models/mimi/test_modeling_mimi.py new file mode 100644 index 000000000..932203b6f --- /dev/null +++ b/tests/transformers/models/mimi/test_modeling_mimi.py @@ -0,0 +1,860 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Mimi model.""" + +import inspect +import os +import tempfile +import unittest + +import numpy as np +from datasets import Audio, load_dataset +from parameterized import parameterized +from pytest import mark + +from mindnlp.transformers import AutoFeatureExtractor, MimiConfig +from mindnlp.transformers.testing_utils import ( + is_flaky, + is_torch_available, + require_flash_attn, + require_torch, + require_torch_gpu, + require_torch_sdpa, + slow, + torch_device, +) +from mindnlp.transformers.utils import ( + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel + + +if is_torch_available(): + import torch + + from transformers import MimiModel + + +# Copied from transformers.tests.encodec.test_modeling_encodec.prepare_inputs_dict +def prepare_inputs_dict( + config, + input_ids=None, + input_values=None, + decoder_input_ids=None, + attention_mask=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, +): + if input_ids is not None: + encoder_dict = {"input_ids": input_ids} + else: + encoder_dict = {"input_values": input_values} + + decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {} + + return {**encoder_dict, **decoder_dict} + + +@require_torch +class MimiModelTester: + def __init__( + self, + parent, + batch_size=5, + num_channels=1, + is_training=False, + intermediate_size=40, + hidden_size=32, + num_filters=8, + num_residual_layers=1, + upsampling_ratios=[8, 4], + codebook_size=64, + vector_quantization_hidden_dimension=64, + codebook_dim=64, + upsample_groups=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + sliding_window=4, + use_cache=False, + ): + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.is_training = is_training + self.intermediate_size = intermediate_size + self.hidden_size = hidden_size + self.num_filters = num_filters + self.num_residual_layers = num_residual_layers + self.upsampling_ratios = upsampling_ratios + self.codebook_size = codebook_size + self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension + self.codebook_dim = codebook_dim + self.upsample_groups = upsample_groups + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.sliding_window = sliding_window + self.use_cache = use_cache + + def prepare_config_and_inputs(self): + input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0) + config = self.get_config() + inputs_dict = {"input_values": input_values} + return config, inputs_dict + + def prepare_config_and_inputs_for_common(self): + config, inputs_dict = self.prepare_config_and_inputs() + return config, inputs_dict + + def prepare_config_and_inputs_for_model_class(self, model_class): + config, inputs_dict = self.prepare_config_and_inputs() + inputs_dict["audio_codes"] = ids_tensor([self.batch_size, 1, self.num_channels], self.codebook_size).type( + torch.int32 + ) + + return config, inputs_dict + + def get_config(self): + return MimiConfig( + audio_channels=self.num_channels, + chunk_in_sec=None, + hidden_size=self.hidden_size, + num_filters=self.num_filters, + num_residual_layers=self.num_residual_layers, + upsampling_ratios=self.upsampling_ratios, + codebook_size=self.codebook_size, + vector_quantization_hidden_dimension=self.vector_quantization_hidden_dimension, + upsample_groups=self.upsample_groups, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + num_key_value_heads=self.num_key_value_heads, + sliding_window=self.sliding_window, + codebook_dim=self.codebook_dim, + use_cache=self.use_cache, + ) + + def create_and_check_model_forward(self, config, inputs_dict): + model = MimiModel(config=config).to(torch_device).eval() + + input_values = inputs_dict["input_values"] + result = model(input_values) + self.parent.assertEqual( + result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size) + ) + + +@require_torch +class MimiModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (MimiModel,) if is_torch_available() else () + is_encoder_decoder = True + test_pruning = False + test_headmasking = False + test_resize_embeddings = False + test_torchscript = False + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + # model does support returning hidden states + inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) + if "output_attentions" in inputs_dict: + inputs_dict.pop("output_attentions") + if "output_hidden_states" in inputs_dict: + inputs_dict.pop("output_hidden_states") + return inputs_dict + + def setUp(self): + self.model_tester = MimiModelTester(self) + self.config_tester = ConfigTester( + self, config_class=MimiConfig, hidden_size=37, common_properties=[], has_text_modality=False + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_forward(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_values", "padding_mask", "num_quantizers"] + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + + @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic") + def test_torchscript_output_hidden_state(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest._create_and_check_torchscript + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + self.skipTest(reason="test_torchscript is set to False") + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + inputs = self._prepare_for_class(inputs_dict, model_class) + + main_input_name = model_class.main_input_name + + try: + main_input = inputs[main_input_name] + model(main_input) + traced_model = torch.jit.trace(model, main_input) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # Avoid memory leak. Without this, each call increase RAM usage by ~20MB. + # (Even with this call, there are still memory leak by ~0.04MB) + self.clear_torch_jit_class_registry() + + @unittest.skip(reason="The MimiModel does not have the usual `attention` logic") + def test_attention_outputs(self): + pass + + @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic") + def test_hidden_states_output(self): + pass + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_determinism + def test_determinism(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_determinism(first, second): + # outputs are not tensors but list (since each sequence don't have the same frame_length) + out_1 = first.cpu().numpy() + out_2 = second.cpu().numpy() + out_1 = out_1[~np.isnan(out_1)] + out_2 = out_2[~np.isnan(out_2)] + max_diff = np.amax(np.abs(out_1 - out_2)) + self.assertLessEqual(max_diff, 1e-5) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + first = model(**self._prepare_for_class(inputs_dict, model_class))[0] + second = model(**self._prepare_for_class(inputs_dict, model_class))[0] + + if isinstance(first, tuple) and isinstance(second, tuple): + for tensor1, tensor2 in zip(first, second): + check_determinism(tensor1, tensor2) + else: + check_determinism(first, second) + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_model_outputs_equivalence + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs) + + self.assertTrue(isinstance(tuple_output, tuple)) + self.assertTrue(isinstance(dict_output, dict)) + + for tuple_value, dict_value in zip(tuple_output, dict_output.values()): + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_value - dict_value))}. Tuple has `nan`:" + f" {torch.isnan(tuple_value).any()} and `inf`: {torch.isinf(tuple_value)}. Dict has" + f" `nan`: {torch.isnan(dict_value).any()} and `inf`: {torch.isinf(dict_value)}." + ), + ) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + uniform_init_parms = ["conv", "input_proj", "output_proj"] + if param.requires_grad: + if any(x in name for x in uniform_init_parms): + self.assertTrue( + -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + # Copied from transformers.tests.encodec.test_modeling_encodec.MimiModelTest.test_identity_shortcut + def test_identity_shortcut(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs() + config.use_conv_shortcut = False + self.model_tester.create_and_check_model_forward(config, inputs_dict) + + # Overwrite to use `audio_values` as the tensors to compare. + # TODO: Try to do this in the parent class. + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + if torch_dtype == "float16" and torch_device == "cpu": + self.skipTest("`replication_pad1d` not implemented for 'Half") + + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors. + # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask. + # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code. + # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it. + deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters + + is_encoder_decoder = model.config.is_encoder_decoder + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for output_attentions in [True, False]: + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + continue + for batch_size in [7]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + if dummy_input.shape[0] != batch_size: + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + extension = torch.rand( + batch_size - dummy_input.shape[0], + *dummy_input.shape[1:], + dtype=torch_dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + else: + extension = torch.randint( + high=5, + size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]), + dtype=dummy_input.dtype, + device=torch_device, + ) + dummy_input = torch.cat((dummy_input, extension), dim=0).to(torch_device) + + if not use_mask: + dummy_attention_mask = None + else: + dummy_attention_mask = inputs_dict.get("attention_mask", None) + if dummy_attention_mask is None: + if is_encoder_decoder: + seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1] + else: + seqlen = dummy_input.shape[-1] + dummy_attention_mask = ( + torch.ones(batch_size, seqlen).to(torch.int64).to(torch_device) + ) + + dummy_attention_mask = dummy_attention_mask[:batch_size] + if dummy_attention_mask.shape[0] != batch_size: + extension = torch.ones( + batch_size - dummy_attention_mask.shape[0], + *dummy_attention_mask.shape[1:], + dtype=dummy_attention_mask.dtype, + device=torch_device, + ) + dummy_attention_mask = torch.cat((dummy_attention_mask, extension), dim=0) + dummy_attention_mask = dummy_attention_mask.to(torch_device) + + dummy_attention_mask[:] = 1 + if padding_side == "left": + dummy_attention_mask[-1, :2] = 0 + dummy_attention_mask[-1, 2:] = 1 + elif padding_side == "right": + dummy_attention_mask[-1, -2:] = 0 + dummy_attention_mask[-1, :-2] = 1 + + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}" + if is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[ + :batch_size + ] + if decoder_input_ids.shape[0] != batch_size: + extension = torch.ones( + batch_size - decoder_input_ids.shape[0], + *decoder_input_ids.shape[1:], + dtype=decoder_input_ids.dtype, + device=torch_device, + ) + decoder_input_ids = torch.cat((decoder_input_ids, extension), dim=0) + decoder_input_ids = decoder_input_ids.to(torch_device) + + # TODO: never an `attention_mask` arg here? + processed_inputs = { + model.main_input_name: dummy_input, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + else: + processed_inputs = { + model.main_input_name: dummy_input, + "output_hidden_states": True, + } + + # Otherwise fails for e.g. WhisperEncoderModel + if "attention_mask" in inspect.signature(model_eager.forward).parameters: + processed_inputs["attention_mask"] = dummy_attention_mask + + if ( + self.has_attentions + and "output_attentions" in inspect.signature(model_sdpa.forward).parameters + ): + processed_inputs["output_attentions"] = output_attentions + if not deactivate_mask and ( + "bool_masked_pos" in inspect.signature(model_eager.forward).parameters + ): + dummy_mask = torch.ones((self.model_tester.num_masks,)) + + # In case of additional token (like class) we define a custom `mask_length` + if hasattr(self.model_tester, "mask_length"): + mask_length = self.model_tester.mask_length - dummy_mask.size(0) + else: + mask_length = self.model_tester.seq_length - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + if "noise" in inspect.signature(model_eager.forward).parameters: + np.random.seed(2) + num_patches = int( + (self.model_tester.image_size // self.model_tester.patch_size) ** 2 + ) + noise = np.random.uniform(size=(batch_size, num_patches)) + processed_inputs["noise"] = torch.from_numpy(noise) + + # TODO: test gradients as well (& for FA2 as well!) + with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + # Ignore copy + logits_eager = outputs_eager.audio_values + # Ignore copy + logits_sdpa = outputs_sdpa.audio_values + + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. + if use_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @is_flaky() + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + outputs = model(dummy_input) + outputs_fa = model_fa(dummy_input) + + logits = outputs[1] + logits_fa = outputs_fa[1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + @unittest.skip(reason="The MimiModel does not support right padding") + def test_flash_attn_2_inference_equivalence_right_padding(self): + pass + + @unittest.skip(reason="The MimiModel does not have support dynamic compile yet") + def test_sdpa_can_compile_dynamic(self): + pass + + +# Copied from transformers.tests.encodec.test_modeling_encodec.normalize +def normalize(arr): + norm = np.linalg.norm(arr) + normalized_arr = arr / norm + return normalized_arr + + +# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse +def compute_rmse(arr1, arr2): + arr1_normalized = normalize(arr1) + arr2_normalized = normalize(arr2) + return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean()) + + +@slow +@require_torch +class MimiIntegrationTest(unittest.TestCase): + def test_integration_using_cache_decode(self): + expected_rmse = { + "8": 0.0018785292, + "32": 0.0012330565, + } + + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + model_id = "kyutai/mimi" + + model = MimiModel.from_pretrained(model_id, use_cache=True).to(torch_device) + processor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) + audio_sample = librispeech_dummy[-1]["audio"]["array"] + + inputs = processor( + raw_audio=audio_sample, + sampling_rate=processor.sampling_rate, + return_tensors="pt", + ).to(torch_device) + + for num_codebooks, expected_rmse in expected_rmse.items(): + with torch.no_grad(): + # use max bandwith for best possible reconstruction + encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks)) + + audio_codes = encoder_outputs[0] + + decoder_outputs_first_part = model.decode(audio_codes[:, :, : audio_codes.shape[2] // 2]) + decoder_outputs_second_part = model.decode( + audio_codes[:, :, audio_codes.shape[2] // 2 :], + decoder_past_key_values=decoder_outputs_first_part.decoder_past_key_values, + ) + + audio_output_entire_context = model.decode(audio_codes)[0] + audio_output_concat_context = torch.cat( + [decoder_outputs_first_part[0], decoder_outputs_second_part[0]], dim=2 + ) + + # make sure audios are more or less equal + # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0 + rmse = compute_rmse( + audio_output_concat_context.squeeze().cpu().numpy(), + audio_output_entire_context.squeeze().cpu().numpy(), + ) + self.assertTrue(rmse < 1e-3) + + def test_integration(self): + expected_rmses = { + "8": 0.0018785292, + "32": 0.0012330565, + } + expected_codesums = { + "8": 426176, + "32": 1795819, + } + librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + model_id = "kyutai/mimi" + + processor = AutoFeatureExtractor.from_pretrained(model_id) + + librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) + audio_sample = librispeech_dummy[-1]["audio"]["array"] + + inputs = processor( + raw_audio=audio_sample, + sampling_rate=processor.sampling_rate, + return_tensors="pt", + ).to(torch_device) + + for use_cache in [False, True]: + model = MimiModel.from_pretrained(model_id, use_cache=use_cache).to(torch_device) + for num_codebooks, expected_rmse in expected_rmses.items(): + with torch.no_grad(): + # use max bandwith for best possible reconstruction + encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks)) + + audio_code_sums = encoder_outputs[0].sum().cpu().item() + + # make sure audio encoded codes are correct + # assert relative difference less than a threshold, because `audio_code_sums` varies a bit + # depending on torch version + self.assertTrue( + np.abs(audio_code_sums - expected_codesums[num_codebooks]) <= (3e-3 * audio_code_sums) + ) + + input_values_dec = model.decode(encoder_outputs[0], padding_mask=inputs["padding_mask"])[0] + input_values_enc_dec = model( + inputs["input_values"], inputs["padding_mask"], num_quantizers=int(num_codebooks) + )[1] + + # make sure forward and decode gives same result + torch.testing.assert_close(input_values_dec, input_values_enc_dec) + + # make sure shape matches + self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape) + + arr = inputs["input_values"][0].cpu().numpy() + arr_enc_dec = input_values_enc_dec[0].cpu().numpy() + + # make sure audios are more or less equal + # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0 + rmse = compute_rmse(arr, arr_enc_dec) + self.assertTrue(np.abs(rmse - expected_rmse) < 1e-5)