From b588300558462660f4ec0ea6caefb9d38b160983 Mon Sep 17 00:00:00 2001 From: NZ99 Date: Fri, 22 Dec 2023 16:14:07 +0100 Subject: [PATCH 1/7] add initial implementation of mup [to be checked] --- protein_lm/modeling/models/apt/config.py | 21 +++ .../modeling/models/apt/model_pytorch.py | 131 ++++++++++++++++-- 2 files changed, 144 insertions(+), 8 deletions(-) diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py index 36f2c04..0fd9b10 100644 --- a/protein_lm/modeling/models/apt/config.py +++ b/protein_lm/modeling/models/apt/config.py @@ -11,6 +11,16 @@ def __init__( position_embedding="learned", tokenizer=None, max_sequence_length = 1024, + use_mup = False, + query_zero_init = True, + n_layer = None, + initializer_range = 0.02, + mup_init_scale = 1.0, + mup_output_temp = 1.0, + mup_attn_mult = 1.0, + mup_embedding_mult = 1.0, + mup_rp_embedding_mult = 1.0, + mup_width_scale = 2.0, **kwargs ): super().__init__(**kwargs) @@ -18,4 +28,15 @@ def __init__( self.position_embedding = position_embedding self.tokenizer = tokenizer self.max_sequence_length = max_sequence_length + + self.use_mup = use_mup + self.query_zero_init = query_zero_init, + self.n_layer = n_layer + self.initializer_range = initializer_range + self.mup_init_scale = mup_init_scale + self.mup_output_temp = mup_output_temp + self.mup_attn_mult = mup_attn_mult + self.mup_embedding_mult = mup_embedding_mult + self.mup_rp_embedding_mult = mup_rp_embedding_mult + self.mup_width_scale = mup_width_scale diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index f814519..87962e2 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -1,4 +1,5 @@ from typing import Optional, Tuple, Union +import math import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -41,6 +42,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) + + # muP + self.use_mup = config.use_mup self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention @@ -53,15 +57,41 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.is_cross_attention: self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + + #muP -- q_attn + if self.use_mup: + self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.q_attn.bias.zero_() + else: self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) - self.c_proj = Conv1D(self.embed_dim, self.embed_dim) - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) + #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487 + if self.use_mup: + if config.query_zero_init: + _, fanout = self.c_attn.weight.shape + self.c_attn.weight.data[:, :fanout//3] = 0 + self.c_attn.bias.zero_() + + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 + if self.use_mup: + depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) + self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(config.depth_std ** 2 / config.mup_width_scale)) + self.c_proj.bias.zero_() + + if self.use_mup: + self.attn_dropout = nn.Identity() + self.resid_dropout = nn.Identity() + else: + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) self.pruned_heads = set() + + self.rot_emb=None if self.position_embedding == "rope": self.rot_emb=RotaryEmbedding(dim=self.head_dim) @@ -76,8 +106,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: + + #muP + if self.use_mup: + attn_weights = attn_weights / torch.full( + [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device + ) + elif self.scale_attn_weights: attn_weights = attn_weights / torch.full( [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device ) @@ -251,10 +286,31 @@ class APTMLP(nn.Module): def __init__(self, intermediate_size, config): super().__init__() embed_dim = config.hidden_size + + #muP + use_mup = config.use_mup + self.c_fc = Conv1D(intermediate_size, embed_dim) + + #muP -- matrix-like + if use_mup: + self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.c_fc.bias.zero_() + self.c_proj = Conv1D(embed_dim, intermediate_size) + + #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 + if use_mup: + depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) + self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale)) + self.c_proj.bias.zero_() + self.act = ACT2FN[config.activation_function] - self.dropout = nn.Dropout(config.resid_pdrop) + + if use_mup: + self.dropout = nn.Identity() + else: + self.dropout = nn.Dropout(config.resid_pdrop) def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> 
torch.FloatTensor: hidden_states = self.c_fc(hidden_states) @@ -270,14 +326,34 @@ def __init__(self, config, layer_idx=None): hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + #muP + use_mup = config.use_mup + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + #muP -- vector-like + if self.use_mup: + self.ln_1.weight.data.fill_(1.0) + self.ln_1.bias.data.zero_() + self.attn = APTAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + #muP -- vector-like + if use_mup: + self.ln_2.weight.data.fill_(1.0) + self.ln_2.bias.data.zero_() + if config.add_cross_attention: + #muP TO DO: check proper behavior in case of crossattention self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + #muP -- vector-like + if use_mup: + self.ln_cross_attn.weight.data.fill_(1.0) + self.ln_cross_attn.bias.data.zero_() + self.mlp = APTMLP(inner_dim, config) def forward( @@ -353,15 +429,38 @@ class APTModel(GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) - self.embed_dim = config.hidden_size + self.embed_dim = config.hidden_sizeù + use_mup = config.use_mup self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + + #muP -- vector-like, zero if zero init or mantained regardless of width + if use_mup: + if config.wte_zero_init: + self.wte.weight.data.zero_() + else: + self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range) + + if self.wte.padding_idx is not None: + self.wte.weight.data[self.wte.padding_idx].zero_() + self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned" if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + #muP -- vector-like, constant regardless of width + #muP TO DO: check proper behavior in rope & rerope case + if self.use_mup: + self.wpe.weight.data.normal_(0.0, std=config.initializer_range) + + if self.wpe.padding_idx is not None: + self.wpe.weight.data[self.wte.padding_idx].zero_() + self.alibi = None elif self.position_embedding=="alibi": + #muP TO DO: check proper behavior in alibi case + maxpos = config.n_positions attn_heads = config.n_head alibi = create_alibi_tensor(attn_heads,maxpos) @@ -369,10 +468,21 @@ def __init__(self, config): else: raise Exception(f'position_embedding {self.position_embedding} not supported. 
Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi') - self.drop = nn.Dropout(config.embd_pdrop) + #muP + if use_mup: + self.drop = nn.Identity() + else: + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + #muP -- vector-like + if use_mup: + self.ln_f.weight.data.fill_(1.0) + self.ln_f.bias.data.zero_() + # Model parallel self.model_parallel = False self.device_map = None @@ -474,6 +584,7 @@ def forward( if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": position_embeds = self.wpe(position_ids) + position_embeds.mul_(self.mup_rp_embedding_mult) hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds @@ -593,6 +704,10 @@ class APTLMHeadModel(GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) self.transformer = APTModel(config) + + #muP TO DO: check proper behavior for LM head, nothing should be done (?) + #see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472 + #see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # Model parallel From e843c245ddf4a6f7693327e0f24a9ea3e064d806 Mon Sep 17 00:00:00 2001 From: NZ99 Date: Thu, 11 Jan 2024 11:47:02 +0100 Subject: [PATCH 2/7] update mup implementation based on othertea's review --- protein_lm.yml | 1 + protein_lm/modeling/models/apt/config.py | 4 +- .../modeling/models/apt/model_pytorch.py | 50 ++++++++++++------- protein_lm/modeling/scripts/train.py | 2 + protein_lm_cuda.yml | 1 + 5 files changed, 37 insertions(+), 21 deletions(-) diff --git a/protein_lm.yml b/protein_lm.yml index 5ce09c6..aba06ef 100644 --- a/protein_lm.yml +++ b/protein_lm.yml @@ -20,3 +20,4 @@ dependencies: - evaluate - pytest - fair-esm + - mup diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py index 0fd9b10..f19db8d 100644 --- a/protein_lm/modeling/models/apt/config.py +++ b/protein_lm/modeling/models/apt/config.py @@ -16,7 +16,7 @@ def __init__( n_layer = None, initializer_range = 0.02, mup_init_scale = 1.0, - mup_output_temp = 1.0, + mup_output_mult = 1.0, mup_attn_mult = 1.0, mup_embedding_mult = 1.0, mup_rp_embedding_mult = 1.0, @@ -34,7 +34,7 @@ def __init__( self.n_layer = n_layer self.initializer_range = initializer_range self.mup_init_scale = mup_init_scale - self.mup_output_temp = mup_output_temp + self.mup_output_mult = mup_output_mult self.mup_attn_mult = mup_attn_mult self.mup_embedding_mult = mup_embedding_mult self.mup_rp_embedding_mult = mup_rp_embedding_mult diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index 87962e2..afc83c7 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -9,6 +9,7 @@ from transformers.pytorch_utils import Conv1D from transformers.activations import ACT2FN from transformers.utils import logging +from mup import MuReadout from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding from 
protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor @@ -45,6 +46,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): # muP self.use_mup = config.use_mup + self.mup_attn_mult = config.mup_attn_mult self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention @@ -58,27 +60,37 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) self.q_attn = Conv1D(self.embed_dim, self.embed_dim) - #muP -- q_attn + #muP -- c_attn & q_attn, cross attention case if self.use_mup: - self.q_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + # default case -- mup initialization for c_attn and q_attn + self.c_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.c_attn.bias.zero_() + self.q_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) self.q_attn.bias.zero_() - + if config.query_zero_init: + # q_attn same as last third of c_attn in no cross attention case -- zero initialization + self.q_attn.weight.data = 0 + else: self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) - - #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487 - if self.use_mup: - if config.query_zero_init: - _, fanout = self.c_attn.weight.shape - self.c_attn.weight.data[:, :fanout//3] = 0 + + #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487 + if self.use_mup: + # default case -- mup initialization for c_attn + self.c_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) self.c_attn.bias.zero_() + if config.query_zero_init: + # last third of c_attn -- zero initialization + _, fanout = self.c_attn.weight.shape + self.c_attn.weight.data[:, :fanout//3] = 0 + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 if self.use_mup: depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) - self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(config.depth_std ** 2 / config.mup_width_scale)) + self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale)) self.c_proj.bias.zero_() if self.use_mup: @@ -109,9 +121,10 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia #muP if self.use_mup: - attn_weights = attn_weights / torch.full( - [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device - ) + if self.mup_attn_mult: + attn_weights = attn_weights / torch.full( + [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device + ) elif self.scale_attn_weights: attn_weights = attn_weights / torch.full( [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device @@ -260,7 +273,6 @@ def forward( key = torch.cat((past_key, key), dim=-2) value = torch.cat((past_value, value), dim=-2) - if use_cache is True: present = (key, value) else: @@ -429,7 +441,7 @@ class APTModel(GPT2PreTrainedModel): def __init__(self, config): super().__init__(config) - self.embed_dim = config.hidden_sizeù + self.embed_dim = config.hidden_size use_mup = 
config.use_mup self.wte = nn.Embedding(config.vocab_size, self.embed_dim) @@ -705,10 +717,10 @@ def __init__(self, config): super().__init__(config) self.transformer = APTModel(config) - #muP TO DO: check proper behavior for LM head, nothing should be done (?) - #see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L472 - #see also table 8's caption in https://arxiv.org/pdf/2203.03466.pdf - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # muP + # TO DO: look into weight tying. if using weight tying, should we use MuSharedReadout? + # see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L59-L68 + self.lm_head = MuReadout(config.n_embd, config.vocab_size, bias=False, output_mult=config.output_mult, width_mult=config.mup_width_scale) # Model parallel self.model_parallel = False diff --git a/protein_lm/modeling/scripts/train.py b/protein_lm/modeling/scripts/train.py index 7c2d555..e993f2c 100644 --- a/protein_lm/modeling/scripts/train.py +++ b/protein_lm/modeling/scripts/train.py @@ -53,6 +53,8 @@ def train( config_dict["wandb"], ) + # TO DO: add support for mup's optimizers in case use_mup is used, see e.g. https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/optim.py + # available via mup.optim trainer = Trainer( model=model, args=training_args, diff --git a/protein_lm_cuda.yml b/protein_lm_cuda.yml index 3ce6cb4..c6511c8 100644 --- a/protein_lm_cuda.yml +++ b/protein_lm_cuda.yml @@ -20,3 +20,4 @@ dependencies: - evaluate - pytest - fair-esm + - mup \ No newline at end of file From 6877a538eaedca06c51d84d9f9cb2fde1441b56d Mon Sep 17 00:00:00 2001 From: NZ99 Date: Thu, 11 Jan 2024 14:47:30 +0100 Subject: [PATCH 3/7] integrate mureadout and sharedmureadout, document & edit mup config --- protein_lm.yml | 1 - protein_lm/modeling/models/apt/config.py | 26 +++++--- .../modeling/models/apt/model_pytorch.py | 63 +++++++++++++++---- protein_lm_cuda.yml | 3 +- 4 files changed, 70 insertions(+), 23 deletions(-) diff --git a/protein_lm.yml b/protein_lm.yml index aba06ef..5ce09c6 100644 --- a/protein_lm.yml +++ b/protein_lm.yml @@ -20,4 +20,3 @@ dependencies: - evaluate - pytest - fair-esm - - mup diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py index f19db8d..3fd7ef6 100644 --- a/protein_lm/modeling/models/apt/config.py +++ b/protein_lm/modeling/models/apt/config.py @@ -11,16 +11,26 @@ def __init__( position_embedding="learned", tokenizer=None, max_sequence_length = 1024, - use_mup = False, query_zero_init = True, n_layer = None, initializer_range = 0.02, - mup_init_scale = 1.0, + + # whether to use MuParametrization + use_mup = False, + + # whether to initialize the input embedding layer with zero-initialization + wte_zero_init = True, + + # the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56 mup_output_mult = 1.0, - mup_attn_mult = 1.0, + + # whether to scale attention weights by the key dimension instead of its square root + mup_attn_mult = True, + + # the positional embedding multiplier if mup is used mup_embedding_mult = 1.0, - mup_rp_embedding_mult = 1.0, - mup_width_scale = 2.0, + + width_mult_for_weights = 2.0, **kwargs ): super().__init__(**kwargs) @@ -30,13 +40,13 @@ def __init__( self.max_sequence_length = max_sequence_length self.use_mup = use_mup + self.wte_zero_init = wte_zero_init 
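The TO DO in train.py above could be addressed along these lines; this is only a sketch under the assumption that the muP shape bookkeeping happens before the Trainer is constructed, and the widths, head counts and learning rate below are illustrative values, not values from this patch set:

    from mup import set_base_shapes, MuAdamW
    from protein_lm.modeling.models.apt.config import APTConfig
    from protein_lm.modeling.models.apt.model_pytorch import APTLMHeadModel

    # base and delta models only provide parameter shapes; they are never trained
    base = APTLMHeadModel(APTConfig(n_embd=256, n_layer=8, num_attention_heads=16, use_mup=True))
    delta = APTLMHeadModel(APTConfig(n_embd=512, n_layer=8, num_attention_heads=32, use_mup=True))
    model = APTLMHeadModel(APTConfig(n_embd=1024, n_layer=8, num_attention_heads=64, use_mup=True))

    set_base_shapes(model, base, delta=delta)        # attaches infshape metadata to every parameter
    optimizer = MuAdamW(model.parameters(), lr=1e-3) # mup optimizer: width-aware per-group learning rates
    # trainer = Trainer(model=model, args=training_args, optimizers=(optimizer, None), ...)
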
self.query_zero_init = query_zero_init, self.n_layer = n_layer self.initializer_range = initializer_range - self.mup_init_scale = mup_init_scale self.mup_output_mult = mup_output_mult self.mup_attn_mult = mup_attn_mult self.mup_embedding_mult = mup_embedding_mult - self.mup_rp_embedding_mult = mup_rp_embedding_mult - self.mup_width_scale = mup_width_scale + self.width_mult_for_weights = width_mult_for_weights + diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index afc83c7..59690ca 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -9,7 +9,6 @@ from transformers.pytorch_utils import Conv1D from transformers.activations import ACT2FN from transformers.utils import logging -from mup import MuReadout from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor @@ -18,6 +17,47 @@ logger = logging.get_logger(__name__) +class APTMuReadout(nn.Linear): + '''Drop-in replacement for all output linear layers. + + An "output" linear layer is one that maps from a width dimension (e.g., + `d_model` in a Transformer) to a non-width dimension (e.g., vocab size). + + This layer implements the version of μP with a 1/width multiplier and a + constant variance initialization for both weights and biases. + ''' + def __init__(self, *args, readout_zero_init=False, output_mult=1.0, width_mult=1.0, **kwargs): + self.output_mult = output_mult + self.readout_zero_init = readout_zero_init + self.width_mult_val = width_mult + super().__init__(*args, **kwargs) + + def width_mult(self): + return self.width_mult_val + + def reset_parameters(self) -> None: + if self.readout_zero_init: + self.weight.data[:] = 0 + if self.bias is not None: + self.bias.data[:] = 0 + else: + super().reset_parameters() + + def forward(self, x): + return super().forward( + self.output_mult * x / self.width_mult()) + +class APTMuSharedReadout(APTMuReadout): + '''`APTMuReadout` with weights shared with an `nn.Embedding` layer. 
+ + Inputs: + weight: should be weight of an `nn.Embedding` layer + other inputs are fed to `MuReadout` + ''' + def __init__(self, weight, bias=True, **kwargs): + super().__init__(*weight.shape, bias=bias, **kwargs) + self.weight = weight + class APTAttention(GPT2Attention): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) @@ -63,9 +103,9 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): #muP -- c_attn & q_attn, cross attention case if self.use_mup: # default case -- mup initialization for c_attn and q_attn - self.c_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.c_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) self.c_attn.bias.zero_() - self.q_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.q_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) self.q_attn.bias.zero_() if config.query_zero_init: # q_attn same as last third of c_attn in no cross attention case -- zero initialization @@ -77,7 +117,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487 if self.use_mup: # default case -- mup initialization for c_attn - self.c_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.c_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) self.c_attn.bias.zero_() if config.query_zero_init: # last third of c_attn -- zero initialization @@ -90,7 +130,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 if self.use_mup: depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) - self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale)) + self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.width_mult_for_weights)) self.c_proj.bias.zero_() if self.use_mup: @@ -306,7 +346,7 @@ def __init__(self, intermediate_size, config): #muP -- matrix-like if use_mup: - self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.mup_width_scale)) + self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) self.c_fc.bias.zero_() self.c_proj = Conv1D(embed_dim, intermediate_size) @@ -314,7 +354,7 @@ def __init__(self, intermediate_size, config): #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 if use_mup: depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) - self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.mup_width_scale)) + self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.width_mult_for_weights)) self.c_proj.bias.zero_() self.act = ACT2FN[config.activation_function] @@ -446,7 +486,7 @@ def __init__(self, config): self.wte = 
nn.Embedding(config.vocab_size, self.embed_dim) - #muP -- vector-like, zero if zero init or mantained regardless of width + #muP -- vector-like, zero if zero init or mantained constant regardless of width if use_mup: if config.wte_zero_init: self.wte.weight.data.zero_() @@ -596,7 +636,7 @@ def forward( if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": position_embeds = self.wpe(position_ids) - position_embeds.mul_(self.mup_rp_embedding_mult) + position_embeds.mul_(self.mup_embedding_mult) hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds @@ -718,9 +758,8 @@ def __init__(self, config): self.transformer = APTModel(config) # muP - # TO DO: look into weight tying. if using weight tying, should we use MuSharedReadout? - # see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L59-L68 - self.lm_head = MuReadout(config.n_embd, config.vocab_size, bias=False, output_mult=config.output_mult, width_mult=config.mup_width_scale) + # TO DO: look into weight tying. if using weight tying, APTMuSharedReadout should be used instead + self.lm_head = APTMuReadout(config.n_embd, config.vocab_size, bias=False, output_mult=config.output_mult, width_mult=config.width_mult_for_weights) # Model parallel self.model_parallel = False diff --git a/protein_lm_cuda.yml b/protein_lm_cuda.yml index c6511c8..aef0b60 100644 --- a/protein_lm_cuda.yml +++ b/protein_lm_cuda.yml @@ -19,5 +19,4 @@ dependencies: - accelerate - evaluate - pytest - - fair-esm - - mup \ No newline at end of file + - fair-esm \ No newline at end of file From 4ff7462aae8252f807f0edeebca725412066e1e2 Mon Sep 17 00:00:00 2001 From: NZ99 Date: Thu, 11 Jan 2024 15:46:12 +0100 Subject: [PATCH 4/7] add zero-initialization to the readout layer --- protein_lm/modeling/models/apt/config.py | 4 ++++ protein_lm/modeling/models/apt/model_pytorch.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py index 3fd7ef6..68cce03 100644 --- a/protein_lm/modeling/models/apt/config.py +++ b/protein_lm/modeling/models/apt/config.py @@ -21,6 +21,9 @@ def __init__( # whether to initialize the input embedding layer with zero-initialization wte_zero_init = True, + # whether to initialize the output (readout) layer with zero-initialization + readout_zero_init = True, + # the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56 mup_output_mult = 1.0, @@ -44,6 +47,7 @@ def __init__( self.query_zero_init = query_zero_init, self.n_layer = n_layer self.initializer_range = initializer_range + self.readout_zero_init = readout_zero_init self.mup_output_mult = mup_output_mult self.mup_attn_mult = mup_attn_mult self.mup_embedding_mult = mup_embedding_mult diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index 59690ca..a330d39 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -759,7 +759,11 @@ def __init__(self, config): # muP # TO DO: look into weight tying. 
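If the weight tying mentioned in this TO DO were adopted, one way to use APTMuSharedReadout for it is sketched below. This is an illustration only, not part of this patch; it assumes the readout should reuse the token embedding matrix of self.transformer.wte while keeping the muP output scaling, and would replace the lm_head assignment in APTLMHeadModel.__init__:

    # hypothetical tied-weight readout (sketch): share the wte matrix with the output layer
    self.lm_head = APTMuSharedReadout(
        self.transformer.wte.weight,              # weight of the input nn.Embedding
        bias=False,
        output_mult=config.mup_output_mult,
        width_mult=config.width_mult_for_weights,
    )
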
if using weight tying, APTMuSharedReadout should be used instead - self.lm_head = APTMuReadout(config.n_embd, config.vocab_size, bias=False, output_mult=config.output_mult, width_mult=config.width_mult_for_weights) + self.lm_head = APTMuReadout(config.n_embd, config.vocab_size, + bias=False, + readout_zero_init=config.readout_zero_init, + output_mult=config.output_mult, + width_mult=config.width_mult_for_weights) # Model parallel self.model_parallel = False From 0648a174a3dc69e1a0f1017c0335940d1d0b9fa5 Mon Sep 17 00:00:00 2001 From: NZ99 Date: Thu, 25 Jan 2024 10:25:00 +0100 Subject: [PATCH 5/7] refactor mup implementation, add coordinate checking test --- protein_lm.yml | 1 + protein_lm/modeling/__init__.py | 1 + protein_lm/modeling/models/__init__.py | 1 + protein_lm/modeling/models/apt/__init__.py | 2 + protein_lm/modeling/models/apt/config.py | 24 +- .../modeling/models/apt/model_pytorch.py | 249 +++++++----------- protein_lm/tests/test_coord_check.py | 37 +++ protein_lm_cuda.yml | 3 +- 8 files changed, 145 insertions(+), 173 deletions(-) create mode 100644 protein_lm/tests/test_coord_check.py diff --git a/protein_lm.yml b/protein_lm.yml index 5ce09c6..aba06ef 100644 --- a/protein_lm.yml +++ b/protein_lm.yml @@ -20,3 +20,4 @@ dependencies: - evaluate - pytest - fair-esm + - mup diff --git a/protein_lm/modeling/__init__.py b/protein_lm/modeling/__init__.py index e69de29..173d567 100644 --- a/protein_lm/modeling/__init__.py +++ b/protein_lm/modeling/__init__.py @@ -0,0 +1 @@ +from models import * \ No newline at end of file diff --git a/protein_lm/modeling/models/__init__.py b/protein_lm/modeling/models/__init__.py index e69de29..20b24e2 100644 --- a/protein_lm/modeling/models/__init__.py +++ b/protein_lm/modeling/models/__init__.py @@ -0,0 +1 @@ +from apt import * \ No newline at end of file diff --git a/protein_lm/modeling/models/apt/__init__.py b/protein_lm/modeling/models/apt/__init__.py index e69de29..a098708 100644 --- a/protein_lm/modeling/models/apt/__init__.py +++ b/protein_lm/modeling/models/apt/__init__.py @@ -0,0 +1,2 @@ +from config import APTConfig +from model_pytorch import * \ No newline at end of file diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py index 68cce03..9d592d2 100644 --- a/protein_lm/modeling/models/apt/config.py +++ b/protein_lm/modeling/models/apt/config.py @@ -13,27 +13,18 @@ def __init__( max_sequence_length = 1024, query_zero_init = True, n_layer = None, + contact_prediction_head = False, initializer_range = 0.02, - # whether to use MuParametrization use_mup = False, - - # whether to initialize the input embedding layer with zero-initialization - wte_zero_init = True, - # whether to initialize the output (readout) layer with zero-initialization readout_zero_init = True, - # the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56 mup_output_mult = 1.0, - - # whether to scale attention weights by the key dimension instead of its square root - mup_attn_mult = True, - - # the positional embedding multiplier if mup is used - mup_embedding_mult = 1.0, - width_mult_for_weights = 2.0, + # rope + rope_theta = 0.0, + rope_scaling_factor=1, **kwargs ): super().__init__(**kwargs) @@ -43,14 +34,13 @@ def __init__( self.max_sequence_length = max_sequence_length self.use_mup = use_mup - self.wte_zero_init = wte_zero_init self.query_zero_init = query_zero_init, self.n_layer = n_layer + self.contact_prediction_head = 
contact_prediction_head self.initializer_range = initializer_range self.readout_zero_init = readout_zero_init self.mup_output_mult = mup_output_mult - self.mup_attn_mult = mup_attn_mult - self.mup_embedding_mult = mup_embedding_mult self.width_mult_for_weights = width_mult_for_weights - + self.rope_theta = rope_theta + self.rope_scaling_factor = rope_scaling_factor diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index a330d39..6634b68 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -9,6 +9,7 @@ from transformers.pytorch_utils import Conv1D from transformers.activations import ACT2FN from transformers.utils import logging +from mup import MuReadout, MuSharedReadout, normal_ from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor @@ -17,46 +18,6 @@ logger = logging.get_logger(__name__) -class APTMuReadout(nn.Linear): - '''Drop-in replacement for all output linear layers. - - An "output" linear layer is one that maps from a width dimension (e.g., - `d_model` in a Transformer) to a non-width dimension (e.g., vocab size). - - This layer implements the version of μP with a 1/width multiplier and a - constant variance initialization for both weights and biases. - ''' - def __init__(self, *args, readout_zero_init=False, output_mult=1.0, width_mult=1.0, **kwargs): - self.output_mult = output_mult - self.readout_zero_init = readout_zero_init - self.width_mult_val = width_mult - super().__init__(*args, **kwargs) - - def width_mult(self): - return self.width_mult_val - - def reset_parameters(self) -> None: - if self.readout_zero_init: - self.weight.data[:] = 0 - if self.bias is not None: - self.bias.data[:] = 0 - else: - super().reset_parameters() - - def forward(self, x): - return super().forward( - self.output_mult * x / self.width_mult()) - -class APTMuSharedReadout(APTMuReadout): - '''`APTMuReadout` with weights shared with an `nn.Embedding` layer. - - Inputs: - weight: should be weight of an `nn.Embedding` layer - other inputs are fed to `MuReadout` - ''' - def __init__(self, weight, bias=True, **kwargs): - super().__init__(*weight.shape, bias=bias, **kwargs) - self.weight = weight class APTAttention(GPT2Attention): def __init__(self, config, is_cross_attention=False, layer_idx=None): @@ -83,10 +44,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads})." 
) - + # muP self.use_mup = config.use_mup - self.mup_attn_mult = config.mup_attn_mult + self.attn_score = nn.Identity() # just for coordcheck + self.query = nn.Identity() # just for coordcheck + self.key = nn.Identity() # just for coordcheck + self.value = nn.Identity() # just for coordcheck self.scale_attn_weights = config.scale_attn_weights self.is_cross_attention = is_cross_attention @@ -99,40 +63,11 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): if self.is_cross_attention: self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) self.q_attn = Conv1D(self.embed_dim, self.embed_dim) - - #muP -- c_attn & q_attn, cross attention case - if self.use_mup: - # default case -- mup initialization for c_attn and q_attn - self.c_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) - self.c_attn.bias.zero_() - self.q_attn.weight.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) - self.q_attn.bias.zero_() - if config.query_zero_init: - # q_attn same as last third of c_attn in no cross attention case -- zero initialization - self.q_attn.weight.data = 0 - else: self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) - - #muP -- c_attn specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L487 - if self.use_mup: - # default case -- mup initialization for c_attn - self.c_attn.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) - self.c_attn.bias.zero_() - if config.query_zero_init: - # last third of c_attn -- zero initialization - _, fanout = self.c_attn.weight.shape - self.c_attn.weight.data[:, :fanout//3] = 0 - self.c_proj = Conv1D(self.embed_dim, self.embed_dim) - - #muP -- c_proj specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 - if self.use_mup: - depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) - self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.width_mult_for_weights)) - self.c_proj.bias.zero_() - + if self.use_mup: self.attn_dropout = nn.Identity() self.resid_dropout = nn.Identity() @@ -142,7 +77,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): self.pruned_heads = set() - + self.rot_emb=None if self.position_embedding == "rope": @@ -154,21 +89,23 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None): elif self.position_embedding=="dynamic_rope_scaling": self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta) - + def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None): attn_weights = torch.matmul(query, key.transpose(-1, -2)) - + #muP - if self.use_mup: - if self.mup_attn_mult: + if self.scale_attn_weights: + if self.use_mup: attn_weights = attn_weights / torch.full( [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device ) - elif self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - ) + else: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + attn_weights = self.attn_score(attn_weights) # Layer-wise attention scaling if 
self.scale_attn_by_inverse_layer_idx: @@ -185,7 +122,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) if alibi_bias is not None: attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)] - + if attention_mask is not None: # Apply the attention mask attn_weights = attn_weights + attention_mask @@ -238,7 +175,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea if alibi_bias is not None: attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)] - + if attention_mask is not None: # Apply the attention mask attn_weights = attn_weights + attention_mask @@ -259,7 +196,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea return attn_output, attn_weights - + def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], @@ -290,11 +227,15 @@ def forward( query = self._split_heads(query, self.num_heads, self.head_dim) key = self._split_heads(key, self.num_heads, self.head_dim) value = self._split_heads(value, self.num_heads, self.head_dim) - + + query = self.query(query) + key = self.key(key) + value = self.value(value) + kv_seq_len=key.shape[-2] if layer_past is not None: kv_seq_len+=layer_past[0].shape[-2] - + # Apply rope embedding to query and key if self.rot_emb: bsz, q_len, _ = hidden_states.size() @@ -344,19 +285,8 @@ def __init__(self, intermediate_size, config): self.c_fc = Conv1D(intermediate_size, embed_dim) - #muP -- matrix-like - if use_mup: - self.c_fc.weight.data.normal_(mean=0.0, std=math.sqrt((config.initializer_range ** 2) / config.width_mult_for_weights)) - self.c_fc.bias.zero_() - self.c_proj = Conv1D(embed_dim, intermediate_size) - #muP -- matrix-like, c_proj-specific, see https://github.com/cofe-ai/Mu-scaling/blob/a3d13c3eaa02e28cc6ef95116358b128424f9fe4/modeling/modeling_gpt2_mup.py#L494 - if use_mup: - depth_std = config.initializer_range / math.sqrt(2 * config.n_layer) - self.c_proj.weight.data.normal_(mean=0.0, std=math.sqrt(depth_std ** 2 / config.width_mult_for_weights)) - self.c_proj.bias.zero_() - self.act = ACT2FN[config.activation_function] if use_mup: @@ -379,33 +309,16 @@ def __init__(self, config, layer_idx=None): inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size #muP - use_mup = config.use_mup + self.use_mup = config.use_mup self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - - #muP -- vector-like - if self.use_mup: - self.ln_1.weight.data.fill_(1.0) - self.ln_1.bias.data.zero_() - self.attn = APTAttention(config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - #muP -- vector-like - if use_mup: - self.ln_2.weight.data.fill_(1.0) - self.ln_2.bias.data.zero_() - if config.add_cross_attention: - #muP TO DO: check proper behavior in case of crossattention self.crossattention = APTAttention(config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - #muP -- vector-like - if use_mup: - self.ln_cross_attn.weight.data.fill_(1.0) - self.ln_cross_attn.bias.data.zero_() - self.mlp = APTMLP(inner_dim, config) def forward( @@ -485,41 +398,21 @@ def __init__(self, config): use_mup = config.use_mup self.wte = nn.Embedding(config.vocab_size, self.embed_dim) - - #muP -- vector-like, zero if zero init or mantained constant regardless of width - if use_mup: - if config.wte_zero_init: - 
self.wte.weight.data.zero_() - else: - self.wte.weight.data.normal_(mean=0.0, std=config.initializer_range) - - if self.wte.padding_idx is not None: - self.wte.weight.data[self.wte.padding_idx].zero_() self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned" - + if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - - #muP -- vector-like, constant regardless of width - #muP TO DO: check proper behavior in rope & rerope case - if self.use_mup: - self.wpe.weight.data.normal_(0.0, std=config.initializer_range) - - if self.wpe.padding_idx is not None: - self.wpe.weight.data[self.wte.padding_idx].zero_() - self.alibi = None elif self.position_embedding=="alibi": - #muP TO DO: check proper behavior in alibi case - + #muP TO DO: check proper behavior in alibi case maxpos = config.n_positions attn_heads = config.n_head alibi = create_alibi_tensor(attn_heads,maxpos) self.register_buffer('alibi',alibi) else: raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi') - + #muP if use_mup: self.drop = nn.Identity() @@ -530,11 +423,6 @@ def __init__(self, config): self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) - #muP -- vector-like - if use_mup: - self.ln_f.weight.data.fill_(1.0) - self.ln_f.bias.data.zero_() - # Model parallel self.model_parallel = False self.device_map = None @@ -636,11 +524,10 @@ def forward( if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling": position_embeds = self.wpe(position_ids) - position_embeds.mul_(self.mup_embedding_mult) hidden_states = inputs_embeds + position_embeds else: hidden_states = inputs_embeds - + if token_type_ids is not None: token_type_embeds = self.wte(token_type_ids) @@ -758,24 +645,76 @@ def __init__(self, config): self.transformer = APTModel(config) # muP - # TO DO: look into weight tying. 
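A standalone toy example of the scaling rule implemented in APTAttention._attn above: the standard GPT-2 path divides the attention logits by sqrt(head_dim), while the muP branch divides by head_dim itself, which is the scaling muP prescribes for stable attention behavior as width grows. The tensor shapes below are arbitrary and only for illustration:

    import math
    import torch

    q = torch.randn(2, 4, 16, 64)                 # (batch, heads, seq, head_dim)
    k = torch.randn(2, 4, 16, 64)
    logits = torch.matmul(q, k.transpose(-1, -2))
    standard = logits / math.sqrt(q.size(-1))     # default: 1 / sqrt(head_dim)
    mup_scaled = logits / q.size(-1)              # use_mup branch: 1 / head_dim
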
if using weight tying, APTMuSharedReadout should be used instead - self.lm_head = APTMuReadout(config.n_embd, config.vocab_size, - bias=False, - readout_zero_init=config.readout_zero_init, - output_mult=config.output_mult, - width_mult=config.width_mult_for_weights) + # TO DO: look into weight tying + # TO DO: if weight tying is used, APTMuSharedReadout with the proper tied weight should be used instead + self.lm_head = MuReadout(config.n_embd, + config.vocab_size, + bias=False, + readout_zero_init=config.readout_zero_init, + output_mult=config.output_mult) + + # mup + # note that this has to be run after mup.set_base_shape for it to work + # see https://github.com/microsoft/mup#basic-usage + # not sure if this is required here + self.apply(self._init_weights) # Model parallel self.model_parallel = False self.device_map = None - self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads, - prepend_bos=True, - append_eos=True, - eos_idx=2) + # mup implementation does not currently support this + if config.contact_prediction_head: + self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads, + prepend_bos=True, + append_eos=True, + eos_idx=2) # Initialize weights and apply final processing self.post_init() + + # mup + # general function for mup-specific weight initialization + def _init_weights(self, module): + if isinstance(module, (MuReadout, MuSharedReadout)) and self.config.readout_zero_init: + module.weight.data.zero_() + elif isinstance(module, (nn.Linear, Conv1D)): + if hasattr(module.weight, 'infshape'): + normal_(module.weight, mean=0.0, std=self.config.initializer_range) + else: + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, APTAttention): + if hasattr(module, "q_attn"): + # cross attention case + if self.config.query_zero_init: + # q_attn same as last third of c_attn in no cross attention case -- zero initialization + self.q_attn.weight.data = 0 + else: + if self.config.query_zero_init: + _, fanout = module.c_attn.weight.shape + assert fanout % 3 == 0 + module.c_attn.weight.data[:, :fanout//3] = 0 + + depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer) + for name, p in module.named_parameters(): + if "c_proj" in name and "weight" in name: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + if hasattr(p, 'infshape'): + normal_(p, mean=0.0, std=depth_std) + else: + p.data.normal_(mean=0.0, std=depth_std) + + def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/protein_lm/tests/test_coord_check.py b/protein_lm/tests/test_coord_check.py new file mode 100644 index 0000000..e463d5a --- /dev/null +++ b/protein_lm/tests/test_coord_check.py @@ -0,0 +1,37 @@ +from mup import make_base_shapes, set_base_shapes, make_base_shapes, set_base_shapes, get_shapes, MuAdam, MuSGD, MuAdamW +from mup.coord_check import get_coord_data, plot_coord_data +from functools import partial +import torch +from protein_lm.modeling import APTConfig, APTLMHeadModel + + +# not sure how to leverage pytest in the context of coordinate checking +# this is because 
visual inspection of coordinate checking results is necessary +# the test will generate coordinate checking results to a test_results directory for now + +if __name__ == "__main__": + delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True)) + delta_model.apply(delta_model._init_weights) + + base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True)) + base_model.apply(base_model._init_weights) + + def get_mup_apt_model(width): + model = APTLMHeadModel(config=APTConfig(n_embd=width, n_layer=8, num_attention_heads=width//16, n_inner=width, use_mup=True)) + return model + + def set_up_mup_apt_model(width): + model = set_base_shapes(get_mup_apt_model(width), base_model, delta=delta_model) + model.apply(model._init_weights) + return model + + def get_mup_lazy_model(width): + return lambda: set_up_mup_apt_model(width) + + models = {256: get_mup_lazy_model(256), 512: get_mup_lazy_model(512), 1024: get_mup_lazy_model(1024), 2048: get_mup_lazy_model(2048)} + + input_ids = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64) + labels = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64) + dataloader=[{'input_ids': input_ids, 'labels': labels}] + df = get_coord_data(models, dataloader, optimizer='sgd', lr=0.1, dict_in_out=True, output_name='loss', cuda=True, nsteps=10, nseeds=10) + plot_coord_data(df, legend=None, save_to='test_results/apt_coordcheck.jpg') \ No newline at end of file diff --git a/protein_lm_cuda.yml b/protein_lm_cuda.yml index aef0b60..c6511c8 100644 --- a/protein_lm_cuda.yml +++ b/protein_lm_cuda.yml @@ -19,4 +19,5 @@ dependencies: - accelerate - evaluate - pytest - - fair-esm \ No newline at end of file + - fair-esm + - mup \ No newline at end of file From 3f84774a7aa72e1545c3ebac46af7630281ed4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Zanichelli?= <61413132+NZ99@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:35:04 +0100 Subject: [PATCH 6/7] Update protein_lm/modeling/models/apt/__init__.py Co-authored-by: othertea <124535597+othertea@users.noreply.github.com> --- protein_lm/modeling/models/apt/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/protein_lm/modeling/models/apt/__init__.py b/protein_lm/modeling/models/apt/__init__.py index a098708..b03a5b9 100644 --- a/protein_lm/modeling/models/apt/__init__.py +++ b/protein_lm/modeling/models/apt/__init__.py @@ -1,2 +1,2 @@ -from config import APTConfig -from model_pytorch import * \ No newline at end of file +from .config import APTConfig +from .model_pytorch import * \ No newline at end of file From 0ef23e603e706425b11d38253eead6b87dd52e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niccol=C3=B2=20Zanichelli?= <61413132+NZ99@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:37:15 +0100 Subject: [PATCH 7/7] Update protein_lm/modeling/models/apt/model_pytorch.py Co-authored-by: othertea <124535597+othertea@users.noreply.github.com> --- protein_lm/modeling/models/apt/model_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py index 6634b68..36afbcb 100644 --- a/protein_lm/modeling/models/apt/model_pytorch.py +++ b/protein_lm/modeling/models/apt/model_pytorch.py @@ -697,7 +697,7 @@ def _init_weights(self, module): if hasattr(module, "q_attn"): # cross attention case if self.config.query_zero_init: - # q_attn same as last third 
of c_attn in no cross attention case -- zero initialization + # q_attn same as first third of c_attn in no cross attention case -- zero initialization self.q_attn.weight.data = 0 else: if self.config.query_zero_init:
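For reference, the cross-attention branch of _init_weights shown above assigns the integer 0 directly to q_attn.weight.data; a hedged sketch of the in-place zero initialization it presumably intends, operating on the module argument like the rest of the function, would be:

    # hypothetical in-place zero init for the cross-attention query projection (sketch only)
    if hasattr(module, "q_attn") and self.config.query_zero_init:
        module.q_attn.weight.data.zero_()

With the initialization in place, the coordinate-check script added in protein_lm/tests/test_coord_check.py is meant to be run directly and inspected visually; it saves its plot under the test_results directory.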