Falcon inference crash fix for falcon-40b model (huggingface#1161)
yeonsily authored Sep 11, 2024
1 parent feb6545 commit b2c29b1
Showing 1 changed file with 49 additions and 37 deletions.
86 changes: 49 additions & 37 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -87,6 +87,40 @@ def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
return hidden_states


def repeat_kv(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
n_rep: int,
):
"""
Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- Append a num_key_value_heads == 1 check: when there is only one KV head, the KV states can already be broadcast during matmuls, so there is no need to expand and reshape them.
- Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
"""
batch, num_key_value_heads, kv_len, head_dim = key_states.shape
if n_rep == 1 or num_key_value_heads == 1:
return query_states, key_states, value_states, attention_mask

new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
key_states = key_states.reshape(new_kv_shape)
value_states = value_states.reshape(new_kv_shape)

batch, _, q_len, head_dim = query_states.shape
new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
query_states = query_states.reshape(new_q_shape)

if attention_mask is not None:
# Add groups dim and set to 1
attention_mask = attention_mask.unsqueeze(1)

return query_states, key_states, value_states, attention_mask

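For readers tracking the shape bookkeeping described in the docstring above, here is a minimal sanity-check sketch. It is not part of the commit: the toy dimensions are hypothetical, and it simply assumes the repeat_kv helper defined above is in scope.

import torch

# Toy GQA dimensions (hypothetical): 8 query heads grouped over 2 KV heads.
batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 8, 2, 4, 16
n_rep = num_heads // num_kv_heads  # 4 query heads share each KV head

query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)
mask = torch.zeros(batch, 1, seq_len, seq_len)

q, k, v, m = repeat_kv(query, key, value, mask, n_rep)
print(q.shape, k.shape, m.shape)
# torch.Size([1, 2, 4, 4, 16]) torch.Size([1, 2, 1, 4, 16]) torch.Size([1, 1, 1, 4, 4])

# The grouped layouts broadcast inside matmul, so K/V are never materialized n_rep times.
scores = q @ k.transpose(-2, -1)  # torch.Size([1, 2, 4, 4, 4])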

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
@@ -123,40 +157,6 @@ def __init__(self, config: FalconConfig):
self.softmax = Softmax()
self.num_key_value_groups = config.num_attention_heads // config.num_kv_heads

def repeat_kv(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
n_rep: int,
):
"""
Copied from repeat_kv: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them.
- Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion.
The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim)
The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim)
"""
batch, num_key_value_heads, kv_len, head_dim = key_states.shape
if n_rep == 1 or num_key_value_heads == 1:
return query_states, key_states, value_states, attention_mask

new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim)
key_states = key_states.reshape(new_kv_shape)
value_states = value_states.reshape(new_kv_shape)

batch, _, q_len, head_dim = query_states.shape
new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim)
query_states = query_states.reshape(new_q_shape)

if attention_mask is not None:
# Add groups dim and set to 1
attention_mask = attention_mask.unsqueeze(1)

return query_states, key_states, value_states, attention_mask

def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(self.head_dim)
@@ -173,7 +173,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
if attn_mask.dtype == torch.bool:
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))

query, key, value, attn_mask = self.repeat_kv(query, key, value, attn_mask, self.num_key_value_groups)
query, key, value, attn_mask = repeat_kv(query, key, value, attn_mask, self.num_key_value_groups)

attn_weight = self.bmm1(query, key.transpose(-2, -1))
attn_weight += attn_mask
@@ -262,7 +262,7 @@ def __init__(self, config: FalconConfig):
# TODO, Does this affect memory usage?
if self.is_fp8:
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA)
self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config)
self.unfused_scaled_dot_product_attention = ScaledDotProductAttention(config)

self.k_cache = KVCache()
self.v_cache = KVCache()
@@ -353,7 +353,11 @@ def pre_attn_forward(

train_with_flash_attention = self.training and self._use_sdpa and not output_attentions and head_mask is None
(query_layer, key_layer, value_layer) = self._split_heads(
fused_qkv, not use_flash_attention and not self.is_fp8 and not train_with_flash_attention
fused_qkv,
not use_flash_attention
and not self.is_fp8
and not train_with_flash_attention
and not (self.config.num_kv_heads == 8),
)

batch_size, query_length, _, _ = query_layer.shape
@@ -462,6 +466,14 @@ def pre_attn_forward(
query_layer, key_layer, value_layer, attention_mask, 0.0, is_causal=False
)
else:
if query_layer.shape != key_layer.shape:
query_layer, key_layer, value_layer, attention_mask = repeat_kv(
query_layer,
key_layer,
value_layer,
attention_mask,
self.config.num_attention_heads // self.config.num_kv_heads,
)
# Workaround until scaled_dot_product_attention supports broadcast.
if self.training is True and query_layer.shape != key_layer.shape:
key_layer = torch.broadcast_to(key_layer, query_layer.shape)
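For orientation only, a small worked example of the n_rep ratio passed to repeat_kv above, using falcon-40b's commonly cited head counts; these numbers are an assumption made here for illustration, not values read from the diff or from self.config.

# Hypothetical falcon-40b-like head counts (assumed; the real values come from the model config).
num_attention_heads = 128
num_kv_heads = 8
n_rep = num_attention_heads // num_kv_heads
print(n_rep)  # 16: each of the 8 KV heads is shared by 16 query heads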
