Fix RMSNormGated in Zamba2 #35943

Open
wants to merge 95 commits into base: main
Changes from 94 commits
Commits
acd25b7
First commit
pglorio Oct 24, 2024
70639b8
Finish model implementation
pglorio Oct 28, 2024
d111b98
First commit
pglorio Oct 24, 2024
8f36dba
Finish model implementation
pglorio Oct 28, 2024
f0c547c
Merge branch 'zamba2' of https://github.com/Zyphra/transformers_zamba…
pglorio Oct 29, 2024
700fbf0
Register zamba2
pglorio Oct 30, 2024
70a6021
generated modeling and configuration
pglorio Nov 4, 2024
88c4b26
Merge pull request #2 from Zyphra/main
pglorio Nov 5, 2024
685906a
generated modeling and configuration
pglorio Nov 5, 2024
4da8d5f
added hybrid cache
pglorio Nov 5, 2024
6b5a9be
fix attention_mask in mamba
pglorio Nov 5, 2024
248350d
dropped unused loras
pglorio Nov 5, 2024
d1d2c66
fix flash2
pglorio Nov 5, 2024
eb6063e
Merge pull request #3 from Zyphra/main
pglorio Nov 5, 2024
5f5d01e
config docstrings
Nov 6, 2024
c1b7647
fix config and fwd pass
pglorio Nov 7, 2024
979b99b
make fixup fixes
pglorio Nov 7, 2024
9d9b2eb
text_modeling_zamba2
pglorio Nov 9, 2024
3a457f5
Merge pull request #4 from Zyphra/main
pglorio Nov 9, 2024
549d4cb
small fixes
pglorio Nov 9, 2024
987bba9
make fixup fixes
pglorio Nov 11, 2024
ffc2a58
Merge pull request #5 from Zyphra/main
pglorio Nov 11, 2024
9adf85e
Fix modular model converter
pglorio Nov 11, 2024
904da4e
added inheritances in modular, renamed zamba cache
pglorio Nov 19, 2024
4725983
Merge pull request #6 from Zyphra/main
pglorio Nov 19, 2024
0be27d7
modular rebase
pglorio Nov 19, 2024
cc0c549
Rebase
pglorio Nov 19, 2024
ac77a09
new modular conversion
pglorio Nov 20, 2024
e59980e
fix generated modeling file
pglorio Nov 20, 2024
73a647a
fixed import for Zamba2RMSNormGated
pglorio Nov 20, 2024
c2b72a5
modular file cleanup
pglorio Nov 21, 2024
0eb39a5
rebase
pglorio Nov 21, 2024
10a0b1e
make fixup and model tests
pglorio Nov 21, 2024
0270667
dropped inheritance for Zamba2PreTrainedModel
pglorio Nov 23, 2024
189c8c5
make fixup and unit tests
pglorio Nov 23, 2024
fa5f79e
Add inheritance of rope from GemmaRotaryEmbedding
pglorio Dec 5, 2024
8079ae0
moved rope to model init
pglorio Dec 5, 2024
d6206eb
drop del self.self_attn and del self.feed_forward
pglorio Dec 5, 2024
f832699
Rebase onto upstream
pglorio Dec 5, 2024
cf613b7
fix tests
pglorio Dec 5, 2024
337faed
renamed lora -> adapter
pglorio Dec 7, 2024
f1b31a1
rewrote adapter implementation
pglorio Dec 7, 2024
8925c15
rebase
pglorio Dec 7, 2024
11fdd47
fixed tests
pglorio Dec 7, 2024
02dd042
Merge branch 'main' into zamba2
pglorio Dec 18, 2024
5d0a5d4
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
ef055c9
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
b993a78
Fix torch_forward in mamba2 layer
pglorio Dec 19, 2024
bf93251
Dropped adapter in-place sum
pglorio Dec 19, 2024
99708af
removed rope from attention init
pglorio Dec 19, 2024
d9b4a50
updated rope
pglorio Dec 19, 2024
095d853
created get_layers method
pglorio Dec 19, 2024
10ebad5
rebase
pglorio Dec 20, 2024
99e343e
make fixup fix
pglorio Dec 20, 2024
4e40975
make fixup fixes
pglorio Dec 20, 2024
61bb32f
make fixup fixes
pglorio Dec 20, 2024
bb9b24b
fix merge conflicts
pglorio Jan 7, 2025
cb90bb4
update to new attention standard
pglorio Jan 13, 2025
8ed701e
fixes for merge
pglorio Jan 13, 2025
1dbc8c7
update to new attention standard
pglorio Jan 13, 2025
f24e452
make fixup fixes
pglorio Jan 13, 2025
676f862
rebase
pglorio Jan 16, 2025
2b29338
minor fixes
pglorio Jan 16, 2025
b212cb2
cache_position
pglorio Jan 16, 2025
1e3b51e
removed cache_position postion_ids use_cache
pglorio Jan 16, 2025
5ace701
remove config from modular
pglorio Jan 16, 2025
535b631
removed config from modular (2)
pglorio Jan 16, 2025
5a16aa9
rebase
pglorio Jan 16, 2025
1c92266
import apply_rotary_pos_emb from llama
pglorio Jan 16, 2025
99bde93
fixed rope_kwargs
pglorio Jan 16, 2025
baf2ed3
Instantiate cache in Zamba2Model
pglorio Jan 16, 2025
9afb57e
fix cache
pglorio Jan 17, 2025
d1687f9
fix @slow decorator
pglorio Jan 17, 2025
4299889
rebase
pglorio Jan 20, 2025
a0545bf
rebase
pglorio Jan 21, 2025
903f6dc
small fix in modular file
pglorio Jan 21, 2025
14396d7
Update docs/source/en/model_doc/zamba2.md
pglorio Jan 23, 2025
02f5807
several minor fixes
pglorio Jan 23, 2025
bfb0267
inherit mamba2decoder fwd and drop position_ids in mamba
pglorio Jan 23, 2025
b222943
removed docstrings from modular
pglorio Jan 23, 2025
b114ad8
rebase
pglorio Jan 23, 2025
929ee67
reinstate zamba2 attention decoder fwd
pglorio Jan 23, 2025
9007a52
use regex for tied keys
pglorio Jan 24, 2025
f701dbd
Revert "use regex for tied keys"
pglorio Jan 24, 2025
87b938b
use regex for tied keys
pglorio Jan 24, 2025
5e09290
add cpu to slow forward tests
pglorio Jan 24, 2025
8ed2353
dropped config.use_shared_mlp_adapter
pglorio Jan 24, 2025
a9bbd9c
Update docs/source/en/model_doc/zamba2.md
pglorio Jan 24, 2025
1e82757
rebase
pglorio Jan 27, 2025
37bff34
re-convert from modular
pglorio Jan 27, 2025
8e0084c
resolve merge conflicts
pglorio Jan 28, 2025
cd304b5
extended Zamba2RMSNormGated to n_groups>1
pglorio Jan 28, 2025
8f2eb7b
removed einops import
pglorio Jan 28, 2025
be7d81a
set _supports_sdpa = True
pglorio Jan 28, 2025
de9a442
rebase
pglorio Feb 3, 2025
19 changes: 12 additions & 7 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -62,20 +62,23 @@


 class Zamba2RMSNormGated(torch.nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
+        self.group_size = group_size

     def forward(self, hidden_states, gate=None):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
-
         if gate is not None:
             hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
         return self.weight * hidden_states.to(input_dtype)
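For readers following along, below is a small self-contained sketch of the grouped gated RMSNorm introduced in this hunk (the GroupedRMSNormGated name and toy shapes are illustrative, not part of the PR). It also checks that with a single group (group_size equal to the hidden size) the module reduces to the previous ungrouped behavior.

# Illustrative sketch of the grouped gated RMSNorm; not the transformers module itself.
import torch
from torch import nn


class GroupedRMSNormGated(nn.Module):
    def __init__(self, hidden_size, group_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.group_size = group_size

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        if gate is not None:
            # The gate is applied before normalization, as in the gated variant above.
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        *prefix_dims, last_dim = hidden_states.shape
        group_count = last_dim // self.group_size
        # RMS statistics are computed per channel group instead of over the full last dim.
        grouped = hidden_states.view(*prefix_dims, group_count, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = grouped.view(*prefix_dims, last_dim)
        return self.weight * hidden_states.to(input_dtype)


x, gate = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
print(GroupedRMSNormGated(8, group_size=4)(x, gate).shape)  # torch.Size([2, 4, 8])

# With one group the result matches a plain gated RMSNorm over the last dimension.
single = GroupedRMSNormGated(8, group_size=8)
ref = x * nn.functional.silu(gate)
ref = ref * torch.rsqrt(ref.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(single(x, gate), ref, atol=1e-6))  # True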


@@ -601,7 +604,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True

@@ -1227,7 +1232,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True

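Since _supports_sdpa is flipped to True here, Zamba2 can now be loaded with PyTorch's scaled dot product attention backend via attn_implementation="sdpa". A minimal usage sketch, assuming a transformers build that includes this PR and the Zyphra/Zamba2-2.7B checkpoint id (both assumptions for illustration):

# Hedged sketch: loading Zamba2 with the SDPA attention backend now that
# _supports_sdpa is True. The checkpoint id is an assumption for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Zyphra/Zamba2-2.7B"  # assumed Hub checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # would have raised before this change
)

inputs = tokenizer("Zamba2 is a hybrid Mamba2/attention model.", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))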
29 changes: 24 additions & 5 deletions src/transformers/models/zamba2/modular_zamba2.py
@@ -35,7 +35,7 @@
     is_mamba_ssm_available,
 )
 from ..llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
-from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum
+from ..mamba2.modeling_mamba2 import pad_tensor_by_size, reshape_into_chunks, segment_sum
 from ..zamba.modeling_zamba import (
     ZambaAttention,
     ZambaAttentionDecoderLayer,
@@ -70,8 +70,25 @@
 logger = logging.get_logger(__name__)


-class Zamba2RMSNormGated(MambaRMSNormGated):
-    pass
+class Zamba2RMSNormGated(torch.nn.Module):
+    def __init__(self, hidden_size, group_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+        self.group_size = group_size
+
+    def forward(self, hidden_states, gate=None):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        if gate is not None:
+            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+        *prefix_dims, last_dim = hidden_states.shape
+        group_count = last_dim // self.group_size
+        hidden_states_group = hidden_states.view(*prefix_dims, group_count, self.group_size)
+        variance = hidden_states_group.pow(2).mean(-1, keepdim=True)
+        hidden_states_group = hidden_states_group * torch.rsqrt(variance + self.variance_epsilon)
+        hidden_states = hidden_states_group.view(*prefix_dims, group_count * self.group_size)
+        return self.weight * hidden_states.to(input_dtype)


 class Zamba2RMSNorm(ZambaRMSNorm):

Review thread on the new Zamba2RMSNormGated class:

Contributor: I think this will also affect the mamba2 code then (as codestral mamba also uses ngroups > 1), so I'd be for implementing this in the mamba2 code and using modular then. cc @molbap

Contributor Author (pglorio, Jan 29, 2025): @vasqu @molbap Sounds good. Should I go ahead and update mamba2?

Contributor: I think so, but as I'm no maintainer I leave the decision to the others 👀
@@ -334,7 +351,9 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None):
         A = torch.arange(1, self.num_heads + 1)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
-        self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+        self.norm = Zamba2RMSNormGated(
+            self.intermediate_size, group_size=self.intermediate_size // self.n_groups, eps=1e-5
+        )
         self.D = nn.Parameter(torch.ones(self.num_heads))
         self.D._no_weight_decay = True

@@ -896,7 +915,7 @@ class Zamba2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_flex_attn = True
-    _supports_sdpa = False
+    _supports_sdpa = True
     _supports_cache_class = True  # Note: only supports Zamba2HybridDynamicCache
     _is_stateful = True