Support MuParametrization and MuTransfer #64
base: main
Changes from all commits
b588300
e843c24
6877a53
4ff7462
0648a17
3f84774
0ef23e6
@@ -20,3 +20,4 @@ dependencies:
 - evaluate
 - pytest
 - fair-esm
+- mup
@@ -0,0 +1 @@
+from models import *
@@ -0,0 +1 @@
+from apt import *

Review comment: I think you need these to be relative imports? They didn't work as is for me. Alternatively, instead of changing all of these to relative imports we can remove these lines and import them by specifying the full module paths in test_coord_check.py.
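A small sketch of the relative-import form described in the comment above (the module names here are assumptions inferred from the package layout in this diff, not verified against the repository):

```python
# hypothetical __init__.py contents using relative imports instead of absolute ones
from .models import *   # rather than `from models import *`
from .apt import *      # rather than `from apt import *`
```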
@@ -0,0 +1,2 @@
+from .config import APTConfig
+from .model_pytorch import *
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -8,6 +9,7 @@
 from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+from mup import MuReadout, MuSharedReadout, normal_
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -16,6 +18,7 @@

 logger = logging.get_logger(__name__)


 class APTAttention(GPT2Attention):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
@@ -42,6 +45,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                 f" {self.num_heads})."
             )

+        # muP
+        self.use_mup = config.use_mup
+        self.attn_score = nn.Identity()  # just for coordcheck
+        self.query = nn.Identity()  # just for coordcheck
+        self.key = nn.Identity()  # just for coordcheck
+        self.value = nn.Identity()  # just for coordcheck
+
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
@@ -55,13 +65,20 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)

         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        if self.use_mup:
+            self.attn_dropout = nn.Identity()
+            self.resid_dropout = nn.Identity()
+        else:
+            self.attn_dropout = nn.Dropout(config.attn_pdrop)
+            self.resid_dropout = nn.Dropout(config.resid_pdrop)

         self.pruned_heads = set()

         self.rot_emb=None
         if self.position_embedding == "rope":
             self.rot_emb=RotaryEmbedding(dim=self.head_dim)

Review comment (on the nn.Identity dropout replacement): Should we consider asserting that the dropout probabilities are set to 0 in this case (in configs)? A possible form of this check is sketched after this hunk.
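On the review question above: a minimal sketch of what such a config-level assertion could look like. The helper name `validate_mup_config` is hypothetical and not part of this PR; the field names follow the GPT-2-style config already used in this file.

```python
# sketch only: guard against silently keeping dropout probabilities > 0 under muP,
# since the muP code path replaces the dropout modules with nn.Identity()
def validate_mup_config(config):
    if getattr(config, "use_mup", False):
        for field in ("attn_pdrop", "resid_pdrop", "embd_pdrop"):
            value = getattr(config, field, 0.0)
            assert value == 0.0, f"use_mup=True expects {field}=0.0, got {value}"
```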
@@ -72,15 +89,23 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         elif self.position_embedding=="dynamic_rope_scaling":
             self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)


     def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))

+        #muP
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
+            if self.use_mup:
+                attn_weights = attn_weights / torch.full(
+                    [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
+                )
+            else:
+                attn_weights = attn_weights / torch.full(
+                    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+                )

+        attn_weights = self.attn_score(attn_weights)

         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:

Review comment (on the use_mup branch above): Should we be multiplying by some …
Reply: Good catch yes, thank you! Will update accordingly
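For context on the (truncated) question above: muP scales attention logits by 1/d_head instead of 1/sqrt(d_head), and muP codebases usually pair this with a width-independent attention multiplier that remains a tunable hyperparameter. A sketch of that variant is below; `attn_mult` is a hypothetical config value, not something defined in this PR.

```python
import torch

# sketch: muP attention-logit scaling with an explicit multiplier.
# The essential change is dividing by d_head (muP) rather than sqrt(d_head);
# `attn_mult` is a hypothetical, width-independent hyperparameter.
def scaled_attn_logits(query, key, use_mup, attn_mult=1.0):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    d_head = query.size(-1)
    if use_mup:
        return attn_weights * (attn_mult / d_head)   # muP: 1/d scaling
    return attn_weights / (d_head ** 0.5)            # standard: 1/sqrt(d) scaling
```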
@@ -97,7 +122,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
         if alibi_bias is not None:
             attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
@@ -150,7 +175,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

         if alibi_bias is not None:
             attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
@@ -171,7 +196,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

         return attn_output, attn_weights


     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -202,11 +227,15 @@ def forward(
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)

+        query = self.query(query)
+        key = self.key(key)
+        value = self.value(value)

         kv_seq_len=key.shape[-2]
         if layer_past is not None:
             kv_seq_len+=layer_past[0].shape[-2]

         # Apply rope embedding to query and key
         if self.rot_emb:
             bsz, q_len, _ = hidden_states.size()
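The nn.Identity() wrappers applied above (and the attn_score wrapper in _attn) exist so that queries, keys, values, and attention scores surface as module outputs, which generic forward hooks, such as the ones a muP coordinate check installs, can then record. A self-contained sketch of that idea (illustrative only, not code from this PR or from the mup package):

```python
import torch
from torch import nn

def attach_probes(model, stats):
    """Record the mean absolute activation at every nn.Identity probe module."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Identity):
            def hook(mod, inputs, output, name=name):
                stats[name] = output.detach().abs().mean().item()
            module.register_forward_hook(hook)

# usage sketch: compare the recorded scales across model widths,
# which is what a muP coordinate check visualizes
# stats = {}; attach_probes(model, stats); model(input_ids); print(stats)
```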
@@ -225,7 +254,6 @@ def forward(
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
             present = (key, value)
         else:
@@ -251,10 +279,20 @@ class APTMLP(nn.Module):
     def __init__(self, intermediate_size, config):
         super().__init__()
         embed_dim = config.hidden_size

+        #muP
+        use_mup = config.use_mup

         self.c_fc = Conv1D(intermediate_size, embed_dim)

         self.c_proj = Conv1D(embed_dim, intermediate_size)

         self.act = ACT2FN[config.activation_function]
-        self.dropout = nn.Dropout(config.resid_pdrop)
+        if use_mup:
+            self.dropout = nn.Identity()
+        else:
+            self.dropout = nn.Dropout(config.resid_pdrop)

     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
@@ -270,6 +308,9 @@ def __init__(self, config, layer_idx=None):
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

+        #muP
+        self.use_mup = config.use_mup

         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = APTAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -354,23 +395,32 @@ def __init__(self, config):
         super().__init__(config)

         self.embed_dim = config.hidden_size
+        use_mup = config.use_mup

         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

         self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

         if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
             self.alibi = None
         elif self.position_embedding=="alibi":
+            #muP TO DO: check proper behavior in alibi case
             maxpos = config.n_positions
             attn_heads = config.n_head
             alibi = create_alibi_tensor(attn_heads,maxpos)
             self.register_buffer('alibi',alibi)
         else:
             raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')

-        self.drop = nn.Dropout(config.embd_pdrop)
+        #muP
+        if use_mup:
+            self.drop = nn.Identity()
+        else:
+            self.drop = nn.Dropout(config.embd_pdrop)

         self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         # Model parallel
@@ -477,7 +527,7 @@ def forward(
             hidden_states = inputs_embeds + position_embeds
         else:
             hidden_states = inputs_embeds


         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
@@ -593,19 +643,78 @@ class APTLMHeadModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = APTModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

+        # muP
+        # TO DO: look into weight tying
+        # TO DO: if weight tying is used, APTMuSharedReadout with the proper tied weight should be used instead
+        self.lm_head = MuReadout(config.n_embd,
+                                 config.vocab_size,
+                                 bias=False,
+                                 readout_zero_init=config.readout_zero_init,
+                                 output_mult=config.output_mult)

Review comment: I believe you have a typo here.
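On the weight-tying TO DO above: mup ships a MuSharedReadout for exactly this case. A rough sketch of a tied-readout variant, assuming the shared weight is the transformer.wte embedding matrix and that MuSharedReadout takes the shared weight as its first argument (worth double-checking against the mup documentation; this is not part of the PR):

```python
from mup import MuSharedReadout

def make_tied_readout(transformer, config):
    # Sketch: LM head whose weight is tied to the token embedding under muP.
    # output_mult mirrors the MuReadout call in this diff.
    return MuSharedReadout(
        transformer.wte.weight,   # tied embedding weight, shape (vocab_size, n_embd)
        bias=False,
        output_mult=config.output_mult,
    )
```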
+        # mup
+        # note that this has to be run after mup.set_base_shape for it to work
+        # see https://github.com/microsoft/mup#basic-usage
+        # not sure if this is required here
+        self.apply(self._init_weights)

Review comment: So it seems to me like we shouldn't call this here? As in your coordinate check example, you will have to call it again anyway (and only if you're using mup?)?
Reply: Right, I think this might have been the result of some earlier testing and of forgetting to remove. Indeed this shouldn't have an effect so no reason to keep. Thanks!
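For reference, the basic workflow the code comment links to (https://github.com/microsoft/mup#basic-usage): set_base_shapes is called once on the target model, weights are re-initialized afterwards so the infshape-aware init takes effect, and a mup optimizer applies the width-dependent learning-rate scaling. A condensed sketch, where `make_model` is a placeholder for constructing APTLMHeadModel at a given width:

```python
from mup import set_base_shapes, MuAdam

# Sketch of the mup basic-usage pattern; `make_model(width)` is a placeholder
# for constructing APTLMHeadModel with a given hidden size.
base_model  = make_model(width=256)    # base width
delta_model = make_model(width=512)    # used only to infer which dims scale with width
model       = make_model(width=2048)   # the model actually trained

set_base_shapes(model, base_model, delta=delta_model)

# re-initialize *after* set_base_shapes, as the review comment above discusses
model.apply(model._init_weights)

optimizer = MuAdam(model.parameters(), lr=1e-3)  # muP-aware optimizer
```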

         # Model parallel
         self.model_parallel = False
         self.device_map = None

-        self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
-                                                prepend_bos=True,
-                                                append_eos=True,
-                                                eos_idx=2)
+        # mup implementation does not currently support this
+        if config.contact_prediction_head:
+            self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
+                                                     prepend_bos=True,
+                                                     append_eos=True,
+                                                     eos_idx=2)

Review comment: Similar to the dropout case, should we consider adding an assertion that we are not using mup with this in the configs?
Reply: Yes, I think this is a good idea!
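A minimal sketch of the assertion discussed above, as a companion to the dropout check sketched earlier in this review (the helper is hypothetical, not part of this PR):

```python
def check_mup_contact_head(config):
    # sketch: muP and the contact prediction head are currently mutually exclusive
    if getattr(config, "use_mup", False):
        assert not getattr(config, "contact_prediction_head", False), (
            "the current muP implementation does not support the contact prediction head"
        )
```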

         # Initialize weights and apply final processing
         self.post_init()

+    # mup
+    # general function for mup-specific weight initialization
+    def _init_weights(self, module):
+        if isinstance(module, (MuReadout, MuSharedReadout)) and self.config.readout_zero_init:
+            module.weight.data.zero_()
+        elif isinstance(module, (nn.Linear, Conv1D)):
+            if hasattr(module.weight, 'infshape'):
+                normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            else:
+                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, APTAttention):
+            if hasattr(module, "q_attn"):
+                # cross attention case
+                if self.config.query_zero_init:
+                    # q_attn same as first third of c_attn in no cross attention case -- zero initialization
+                    module.q_attn.weight.data.zero_()
+            else:
+                if self.config.query_zero_init:
+                    _, fanout = module.c_attn.weight.shape
+                    assert fanout % 3 == 0
+                    module.c_attn.weight.data[:, :fanout//3] = 0
+
+            depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+            for name, p in module.named_parameters():
+                if "c_proj" in name and "weight" in name:
+                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                    if hasattr(p, 'infshape'):
+                        normal_(p, mean=0.0, std=depth_std)
+                    else:
+                        p.data.normal_(mean=0.0, std=depth_std)

     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,