From b2c29b1e54de02f2b19f49851a661c6481fa70c9 Mon Sep 17 00:00:00 2001
From: Yeonsil Yoon
Date: Wed, 11 Sep 2024 14:30:00 -0700
Subject: [PATCH] Falcon inference crash fix for falcon-40b model (#1161)

---
 .../models/falcon/modeling_falcon.py          | 86 +++++++++++--------
 1 file changed, 49 insertions(+), 37 deletions(-)

diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index a7a0c0e920..52fc649948 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -87,6 +87,40 @@ def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
     return hidden_states
 
 
+def repeat_kv(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    n_rep: int,
+):
+    """
+    Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+    The only differences are:
+        - Append a num_key_value_heads == 1 check, as such kv states can be broadcasted during matmuls, so there is no need to expand and reshape them.
+        - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
+    The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
+    The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, kv_len, head_dim = key_states.shape
+    if n_rep == 1 or num_key_value_heads == 1:
+        return query_states, key_states, value_states, attention_mask
+
+    new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
+    key_states = key_states.reshape(new_kv_shape)
+    value_states = value_states.reshape(new_kv_shape)
+
+    batch, _, q_len, head_dim = query_states.shape
+    new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
+    query_states = query_states.reshape(new_q_shape)
+
+    if attention_mask is not None:
+        # Add groups dim and set to 1
+        attention_mask = attention_mask.unsqueeze(1)
+
+    return query_states, key_states, value_states, attention_mask
+
+
 # FusedScaledDotProductAttention
 class ModuleFusedSDPA(torch.nn.Module):
     def __init__(self, fusedSDPA):
@@ -123,40 +157,6 @@ def __init__(self, config: FalconConfig):
         self.softmax = Softmax()
         self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads
 
-    def repeat_kv(
-        self,
-        query_states: torch.Tensor,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        n_rep: int,
-    ):
-        """
-        Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
-        The only differences are:
-            - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them.
-            - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
-        The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
-        The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
-        """
-        batch, num_key_value_heads, kv_len, head_dim = key_states.shape
-        if n_rep == 1 or num_key_value_heads == 1:
-            return query_states, key_states, value_states, attention_mask
-
-        new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
-        key_states = key_states.reshape(new_kv_shape)
-        value_states = value_states.reshape(new_kv_shape)
-
-        batch, _, q_len, head_dim = query_states.shape
-        new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
-        query_states = query_states.reshape(new_q_shape)
-
-        if attention_mask is not None:
-            # Add groups dim and set to 1
-            attention_mask = attention_mask.unsqueeze(1)
-
-        return query_states, key_states, value_states, attention_mask
-
     def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
         L, S = query.size(-2), key.size(-2)
         scale_factor = 1 / math.sqrt(self.head_dim)
@@ -173,7 +173,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa
             if attn_mask.dtype == torch.bool:
                 attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
 
-        query, key, value, attn_mask = self.repeat_kv(query, key, value, attn_mask, self.num_key_value_groups)
+        query, key, value, attn_mask = repeat_kv(query, key, value, attn_mask, self.num_key_value_groups)
 
         attn_weight = self.bmm1(query, key.transpose(-2, -1))
         attn_weight += attn_mask
@@ -262,7 +262,7 @@ def __init__(self, config: FalconConfig):
         # TODO, Does this affect memory usage?
         if self.is_fp8:
             self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA)
-            self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config)
+        self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config)
 
         self.k_cache = KVCache()
         self.v_cache = KVCache()
@@ -353,7 +353,11 @@ def pre_attn_forward(
 
         train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None
         (query_layer, key_layer, value_layer) = self._split_heads(
-            fused_qkv, not use_flash_attention and not self.is_fp8 and not train_with_flash_attention
+            fused_qkv,
+            not use_flash_attention
+            and not self.is_fp8
+            and not train_with_flash_attention
+            and not (self.config.num_kv_heads == 8),
         )
 
         batch_size, query_length, _, _ = query_layer.shape
@@ -462,6 +466,14 @@ def pre_attn_forward(
                         query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
                     )
                 else:
+                    if query_layer.shape != key_layer.shape:
+                        query_layer, key_layer, value_layer, attention_mask = repeat_kv(
+                            query_layer,
+                            key_layer,
+                            value_layer,
+                            attention_mask,
+                            self.config.num_attention_heads // self.config.num_kv_heads,
+                        )
                     # Workaround util scaled_dot_product_attention support broadcast.
                     if self.training is True and query_layer.shape != key_layer.shape:
                         key_layer = torch.broadcast_to(key_layer, query_layer.shape)
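For reference, below is a minimal standalone sketch of the repeat_kv reshape this patch hoists to module level: instead of materializing n_rep copies of the key/value states for grouped-query attention, it inserts a singleton "groups" dimension so the batched matmuls broadcast K/V across the query groups. The head counts here are small illustrative stand-ins for falcon-40b-style GQA (many query heads sharing fewer KV heads), not the model's real sizes.

# Sketch only: same reshape logic as the repeat_kv added by this patch,
# exercised on toy shapes to show how K/V and the mask broadcast.
import torch


def repeat_kv(query_states, key_states, value_states, attention_mask, n_rep):
    batch, num_key_value_heads, kv_len, head_dim = key_states.shape
    if n_rep == 1 or num_key_value_heads == 1:
        return query_states, key_states, value_states, attention_mask

    # K/V get a singleton groups dim; Q is split into (kv_heads, n_rep) groups.
    key_states = key_states.reshape(batch, num_key_value_heads, 1, kv_len, head_dim)
    value_states = value_states.reshape(batch, num_key_value_heads, 1, kv_len, head_dim)

    batch, _, q_len, head_dim = query_states.shape
    query_states = query_states.reshape(batch, num_key_value_heads, n_rep, q_len, head_dim)

    if attention_mask is not None:
        attention_mask = attention_mask.unsqueeze(1)  # add groups dim (size 1)

    return query_states, key_states, value_states, attention_mask


if __name__ == "__main__":
    batch, num_heads, num_kv_heads, q_len, kv_len, head_dim = 1, 16, 4, 5, 7, 8
    n_rep = num_heads // num_kv_heads  # 4 query heads share each KV head

    query = torch.randn(batch, num_heads, q_len, head_dim)
    key = torch.randn(batch, num_kv_heads, kv_len, head_dim)
    value = torch.randn(batch, num_kv_heads, kv_len, head_dim)
    mask = torch.zeros(batch, 1, q_len, kv_len)

    query, key, value, mask = repeat_kv(query, key, value, mask, n_rep)
    print(query.shape)  # torch.Size([1, 4, 4, 5, 8])
    print(key.shape)    # torch.Size([1, 4, 1, 7, 8])
    print(mask.shape)   # torch.Size([1, 1, 1, 5, 7])

    # K/V and the mask broadcast over the n_rep dim during the matmuls,
    # so no physical copies of the KV cache are made.
    attn_weight = query @ key.transpose(-2, -1) + mask
    attn_output = torch.softmax(attn_weight, dim=-1) @ value
    print(attn_weight.shape)  # torch.Size([1, 4, 4, 5, 7])
    print(attn_output.shape)  # torch.Size([1, 4, 4, 5, 8])

Keeping K/V with a singleton groups dimension avoids the memory cost of expanding them n_rep times, which is why the patch calls repeat_kv both in the unfused ScaledDotProductAttention path and, guarded by the query/key shape mismatch check, in pre_attn_forward.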