From 48926f633571773bda4de432112740cc3f7d5149 Mon Sep 17 00:00:00 2001 From: Ben Viggiano Date: Sat, 15 Jun 2024 11:10:00 -0700 Subject: [PATCH 1/2] Fixed tokenization issue impacting transformers version 4.40.0 --- .../datasets/hg38_char_tokenizer.py | 36 +- standalone_hyenadna.py | 637 ++++++++++++------ 2 files changed, 452 insertions(+), 221 deletions(-) diff --git a/src/dataloaders/datasets/hg38_char_tokenizer.py b/src/dataloaders/datasets/hg38_char_tokenizer.py index b60408e..dbe3643 100644 --- a/src/dataloaders/datasets/hg38_char_tokenizer.py +++ b/src/dataloaders/datasets/hg38_char_tokenizer.py @@ -4,6 +4,7 @@ CharacterTokenzier for Hugging Face Transformers. This is heavily inspired from CanineTokenizer in transformers package. """ + import json import os from pathlib import Path @@ -13,7 +14,13 @@ class CharacterTokenizer(PreTrainedTokenizer): - def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs): + def __init__( + self, + characters: Sequence[str], + model_max_length: int, + padding_side: str = "left", + **kwargs + ): """Character tokenizer for Hugging Face transformers. Args: characters (Sequence[str]): List of desired characters. Any character which @@ -41,6 +48,18 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False) + self._vocab_str_to_int = { + "[CLS]": 0, + "[SEP]": 1, + "[BOS]": 2, + "[MASK]": 3, + "[PAD]": 4, + "[RESERVED]": 5, + "[UNK]": 6, + **{ch: i + 7 for i, ch in enumerate(characters)}, + } + self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + super().__init__( bos_token=bos_token, eos_token=sep_token, @@ -55,17 +74,8 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid **kwargs, ) - self._vocab_str_to_int = { - "[CLS]": 0, - "[SEP]": 1, - "[BOS]": 2, - "[MASK]": 3, - "[PAD]": 4, - "[RESERVED]": 5, - "[UNK]": 6, - **{ch: i + 7 for i, ch in enumerate(characters)}, - } - self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + def get_vocab(self) -> Dict[str, int]: + return self._vocab_str_to_int @property def vocab_size(self) -> int: @@ -146,4 +156,4 @@ def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs): cfg_file = Path(save_directory) / "tokenizer_config.json" with open(cfg_file) as f: cfg = json.load(f) - return cls.from_config(cfg) \ No newline at end of file + return cls.from_config(cfg) diff --git a/standalone_hyenadna.py b/standalone_hyenadna.py index 0d36554..26b8656 100644 --- a/standalone_hyenadna.py +++ b/standalone_hyenadna.py @@ -11,7 +11,7 @@ """ -#@title Imports +# @title Imports # for HyenaDNA specifically import torch import math @@ -39,7 +39,7 @@ """ -#@title Hyena layer +# @title Hyena layer def fftconv(u, k, D): @@ -53,8 +53,9 @@ def fftconv(u, k, D): k_f = torch.fft.rfft(k, n=fft_size) / fft_size u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - if len(u.shape) > 3: k_f = k_f.unsqueeze(1) - y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm="forward")[..., :seqlen] out = y + u * D.unsqueeze(-1) return out.to(dtype=u.dtype) @@ -64,8 +65,9 @@ def fftconv(u, k, D): def mul_sum(q, y): return (q * y).sum(dim=1) + class OptimModule(nn.Module): - """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ + """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters""" def register(self, name, tensor, lr=None, wd=0.0): """Register a tensor with a configurable learning rate and 0 weight decay""" @@ -76,35 +78,42 @@ def register(self, name, tensor, lr=None, wd=0.0): self.register_parameter(name, nn.Parameter(tensor)) optim = {} - if lr is not None: optim["lr"] = lr - if wd is not None: optim["weight_decay"] = wd + if lr is not None: + optim["lr"] = lr + if wd is not None: + optim["weight_decay"] = wd setattr(getattr(self, name), "_optim", optim) class Sin(nn.Module): """The Sin activation function for the Hyena Filter function.""" + def __init__(self, dim, w=10, train_freq=True): super().__init__() - self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) + self.freq = ( + nn.Parameter(w * torch.ones(1, dim)) + if train_freq + else w * torch.ones(1, dim) + ) def forward(self, x): return torch.sin(self.freq * x) class PositionalEmbedding(OptimModule): - def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs): + def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float = 1e-5, **kwargs): """Complex exponential positional embeddings for Hyena filters.""" super().__init__() self.seq_len = seq_len # The time embedding fed to the filteres is normalized so that t_f = 1 - t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 if emb_dim > 1: bands = (emb_dim - 1) // 2 # To compute the right embeddings we use the "proper" linspace t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] - w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 f = torch.linspace(1e-4, bands - 1, bands)[None, None] z = torch.exp(-1j * f * w) @@ -118,6 +127,7 @@ def forward(self, L): class ExponentialModulation(OptimModule): """The window function applied to the output of the (MLP) filter function.""" + def __init__( self, d_model, @@ -125,9 +135,9 @@ def __init__( slow_decay_pct=1.5, target=1e-2, modulation_lr=0.0, - modulate: bool=True, + modulate: bool = True, shift: float = 0.05, - **kwargs + **kwargs, ): super().__init__() self.modulate = modulate @@ -146,22 +156,22 @@ def forward(self, t, x): class HyenaFilter(OptimModule): def __init__( - self, - d_model, - emb_dim=3, # dim of input to MLP, augments with positional encoding - order=16, # width of the implicit MLP - fused_fft_conv=False, - seq_len=1024, - lr=1e-3, - lr_pos_emb=1e-5, - dropout=0.0, - w=1, # frequency of periodic activations - wd=0, # weight decay of kernel parameters - bias=True, - num_inner_mlps=2, - normalized=False, - **kwargs - ): + self, + d_model, + emb_dim=3, # dim of input to MLP, augments with positional encoding + order=16, # width of the implicit MLP + fused_fft_conv=False, + seq_len=1024, + lr=1e-3, + lr_pos_emb=1e-5, + dropout=0.0, + w=1, # frequency of periodic activations + wd=0, # weight decay of kernel parameters + bias=True, + num_inner_mlps=2, + normalized=False, + **kwargs, + ): """ Implicit long filter with modulation. @@ -184,7 +194,9 @@ def __init__( act = Sin(dim=order, w=w) self.emb_dim = emb_dim - assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)" + assert ( + emb_dim % 2 != 0 and emb_dim >= 3 + ), "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)" self.seq_len = seq_len self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb) @@ -214,7 +226,8 @@ def filter(self, L, *args, **kwargs): return h def forward(self, x, L, k=None, bias=None, *args, **kwargs): - if k is None: k = self.filter(L) + if k is None: + k = self.filter(L) # Ensure compatibility with filters that return a tuple k = k[0] if type(k) is tuple else k @@ -225,15 +238,15 @@ def forward(self, x, L, k=None, bias=None, *args, **kwargs): class HyenaOperator(nn.Module): def __init__( - self, - d_model, - l_max, - order=2, - filter_order=64, - dropout=0.0, - filter_dropout=0.0, - **filter_args, - ): + self, + d_model, + l_max, + order=2, + filter_order=64, + dropout=0.0, + filter_dropout=0.0, + **filter_args, + ): r""" Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf @@ -255,11 +268,7 @@ def __init__( self.out_proj = nn.Linear(d_model, d_model) self.short_filter = nn.Conv1d( - inner_width, - inner_width, - 3, - padding=2, - groups=inner_width + inner_width, inner_width, 3, padding=2, groups=inner_width ) self.filter_fn = HyenaFilter( d_model * (order - 1), @@ -267,38 +276,40 @@ def __init__( seq_len=l_max, channels=1, dropout=filter_dropout, - **filter_args + **filter_args, ) def forward(self, u, *args, **kwargs): l = u.size(-2) l_filter = min(l, self.l_max) u = self.in_proj(u) - u = rearrange(u, 'b l d -> b d l') + u = rearrange(u, "b l d -> b d l") - uc = self.short_filter(u)[...,:l_filter] + uc = self.short_filter(u)[..., :l_filter] *x, v = uc.split(self.d_model, dim=1) k = self.filter_fn.filter(l_filter)[0] - k = rearrange(k, 'l (o d) -> o d l', o=self.order - 1) - bias = rearrange(self.filter_fn.bias, '(o d) -> o d', o=self.order - 1) + k = rearrange(k, "l (o d) -> o d l", o=self.order - 1) + bias = rearrange(self.filter_fn.bias, "(o d) -> o d", o=self.order - 1) for o, x_i in enumerate(reversed(x[1:])): v = self.dropout(v * x_i) v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o]) - y = rearrange(v * x[0], 'b d l -> b l d') + y = rearrange(v * x[0], "b d l -> b l d") y = self.out_proj(y) return y -#@title Self-Attention (alternative) + +# @title Self-Attention (alternative) """ If you'd like to try the HyenaDNA model using attention instead, you can. ie, use a regular decoder only Transformer. """ + class SelfAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments @@ -309,6 +320,7 @@ class SelfAttention(nn.Module): attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() self.causal = causal @@ -328,36 +340,51 @@ def forward(self, qkv, causal=None, key_padding_mask=None): causal = self.causal if causal is None else causal q, k, v = qkv.unbind(dim=2) softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, - device=scores.device) + padding_mask = torch.full( + (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device + ) padding_mask.masked_fill_(key_padding_mask, 0.0) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + causal_mask = torch.triu( + torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 + ) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1, dtype=v.dtype) attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0) - output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) return output + class MHA(nn.Module): - """Multi-head self-attention and cross-attention - """ + """Multi-head self-attention and cross-attention""" - def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0, - softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None: + def __init__( + self, + embed_dim, + num_heads, + bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + dwconv=False, + return_residual=False, + device=None, + dtype=None, + ) -> None: """ - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. """ - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.causal = causal @@ -366,23 +393,35 @@ def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0, self.return_residual = return_residual self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + assert ( + self.embed_dim % num_heads == 0 + ), "self.kdim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads linear_cls = nn.Linear linear_resid_cls = LinearResidual - inner_attn_cls = SelfAttention + inner_attn_cls = SelfAttention if not self.return_residual: - self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.Wqkv = linear_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) else: - self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.Wqkv = linear_resid_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) if self.dwconv: - self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, - groups=3 * embed_dim) + self.dwconv_qkv = nn.Conv1d( + 3 * embed_dim, + 3 * embed_dim, + kernel_size=3, + padding=2, + groups=3 * embed_dim, + ) - self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, - attention_dropout=dropout) + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) # output projection always have the bias (for now) self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) @@ -406,36 +445,50 @@ def forward(self, x, key_padding_mask=None, **kwargs): https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ - kwargs = ({'key_padding_mask': key_padding_mask, **kwargs}) + kwargs = {"key_padding_mask": key_padding_mask, **kwargs} if not self.return_residual: qkv = self.Wqkv(x) else: qkv, x = self.Wqkv(x) if self.dwconv: - qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], - 'b d s -> b s d').contiguous() - qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim) + qkv = rearrange( + self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], + "b d s -> b s d", + ).contiguous() + qkv = rearrange( + qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim + ) context = self.inner_attn(qkv, **kwargs) - out = self.out_proj(rearrange(context, '... h d -> ... (h d)')) + out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) return out if not self.return_residual else (out, x) -#@title MLP layer + +# @title MLP layer """ The MLP layer after the mixer layer (HyenaOperator). """ + class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, - return_residual=False, device=None, dtype=None): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.gelu, + return_residual=False, + device=None, + dtype=None, + ): """ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py """ - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -450,27 +503,39 @@ def forward(self, x): y = self.fc2(y) return y if not self.return_residual else (y, x) -#@title Block layer (Hyena + MLP layers) + +# @title Block layer (Hyena + MLP layers) """ A block consists of a Mixer layer (Hyena or attention), and a MLP layer. """ + class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense. - """ + """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input), input + class Block(nn.Module): - def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0., - drop_path1=0., drop_path2=0., - return_residual=False, - residual_in_fp32=False): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + return_residual=False, + residual_in_fp32=False, + ): """ From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py For prenorm=True, this Block has a slightly different structure compared to a regular @@ -492,23 +557,24 @@ def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, self.return_residual = return_residual self.residual_in_fp32 = residual_in_fp32 if self.residual_in_fp32: - assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" if mixer_cls is None: mixer_cls = partial(MHA, num_heads=dim // 64) if mlp_cls is None: mlp_cls = partial(Mlp, hidden_features=4 * dim) self.mixer = mixer_cls() self.dropout1 = dropout_cls(resid_dropout1) - self.drop_path1 = StochasticDepth(drop_path1, mode='row') + self.drop_path1 = StochasticDepth(drop_path1, mode="row") self.norm1 = norm_cls(dim) self.mlp = mlp_cls(dim) if not isinstance(self.mlp, nn.Identity): self.dropout2 = dropout_cls(resid_dropout2) - self.drop_path2 = StochasticDepth(drop_path2, mode='row') + self.drop_path2 = StochasticDepth(drop_path2, mode="row") self.norm2 = norm_cls(dim) - def forward(self, hidden_states, residual = None, - mixer_subset=None, mixer_kwargs=None): + def forward( + self, hidden_states, residual=None, mixer_subset=None, mixer_kwargs=None + ): r"""Pass the input through the encoder layer. Args: hidden_states: the sequence to the encoder layer (required). @@ -526,7 +592,7 @@ def forward(self, hidden_states, residual = None, if mixer_kwargs is None: mixer_kwargs = {} if mixer_subset is not None: - mixer_kwargs['mixer_subset'] = mixer_subset + mixer_kwargs["mixer_subset"] = mixer_subset hidden_states = self.mixer(hidden_states, **mixer_kwargs) if mixer_subset is not None: residual = residual[:, mixer_subset] @@ -547,30 +613,47 @@ def forward(self, hidden_states, residual = None, if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out - hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out)) - + hidden_states).to(dtype=self.norm1.weight.dtype)) + hidden_states = self.norm1( + (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to( + dtype=self.norm1.weight.dtype + ) + ) if not isinstance(self.mlp, nn.Identity): mlp_out = self.mlp(hidden_states) if self.return_residual: # mlp out is actually a pair here mlp_out, hidden_states = mlp_out - hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out)) - + hidden_states).to(dtype=self.norm2.weight.dtype)) + hidden_states = self.norm2( + (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.norm2.weight.dtype + ) + ) return hidden_states -def create_mixer_cls(layer=None, - attn_layer_idx=None, attn_cfg=None, layer_idx=None, - device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + +def create_mixer_cls( + layer=None, + attn_layer_idx=None, + attn_cfg=None, + layer_idx=None, + device=None, + dtype=None, +): + factory_kwargs = {"device": device, "dtype": dtype} if attn_layer_idx is not None and layer_idx in attn_layer_idx: - causal = True if attn_cfg is None else attn_cfg.pop('causal', True) + causal = True if attn_cfg is None else attn_cfg.pop("causal", True) mha_cls = MHA - mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx, - **(attn_cfg if attn_cfg is not None else {}),**factory_kwargs) + mixer_cls = partial( + mha_cls, + causal=causal, + layer_idx=layer_idx, + **(attn_cfg if attn_cfg is not None else {}), + **factory_kwargs, + ) else: # mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs) @@ -578,39 +661,67 @@ def create_mixer_cls(layer=None, return mixer_cls + def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} inner_dim = d_inner if d_inner is not None else 4 * d_model - mlp_cls = partial(Mlp, hidden_features=inner_dim, - activation=partial(F.gelu, approximate='tanh'), **factory_kwargs) + mlp_cls = partial( + Mlp, + hidden_features=inner_dim, + activation=partial(F.gelu, approximate="tanh"), + **factory_kwargs, + ) return mlp_cls -def create_block(d_model, d_inner=None, - layer=None, attn_layer_idx=None, - attn_cfg=None, layer_norm_epsilon=1e-5, - resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, - layer_idx=None, - device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - mixer_cls = create_mixer_cls(layer=layer, - attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_idx=layer_idx, - **factory_kwargs) - mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, - **factory_kwargs) +def create_block( + d_model, + d_inner=None, + layer=None, + attn_layer_idx=None, + attn_cfg=None, + layer_norm_epsilon=1e-5, + resid_dropout1=0.0, + resid_dropout2=0.0, + residual_in_fp32=False, + layer_idx=None, + device=None, + dtype=None, +): + factory_kwargs = {"device": device, "dtype": dtype} + mixer_cls = create_mixer_cls( + layer=layer, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + layer_idx=layer_idx, + **factory_kwargs, + ) + mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, **factory_kwargs) norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) - block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2,residual_in_fp32=residual_in_fp32) + block = Block( + d_model, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + prenorm=True, + resid_dropout1=resid_dropout1, + resid_dropout2=resid_dropout2, + residual_in_fp32=residual_in_fp32, + ) block.layer_idx = layer_idx return block # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True, - glu_act=False): +def _init_weights( + module, + n_layer, + initializer_range=0.02, + rescale_prenorm_residual=True, + glu_act=False, +): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, std=initializer_range) if module.bias is not None: @@ -628,19 +739,28 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + nn.init.normal_( + p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) + ) # If using GLU activation for now, we scale the std by 2 elif name in ["output_linear.0.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block if not glu_act: - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + nn.init.normal_( + p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) + ) else: out_features = p.shape[0] # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 # on average. - nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2) + nn.init.normal_( + p[: out_features // 2], + mean=0.0, + std=initializer_range / math.sqrt(2 * n_layer) * 2, + ) + -#@title Backbone model (stack of blocks) +# @title Backbone model (stack of blocks) """ A backbone model consists of a stack of blocks. If you use attention, then @@ -648,35 +768,51 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid revert to doing nothing. """ + class GPT2Embeddings(nn.Module): - def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None, - word_embed_proj_dim=None, device=None, dtype=None): + def __init__( + self, + embed_dim, + vocab_size, + max_position_embeddings, + padding_idx=None, + word_embed_proj_dim=None, + device=None, + dtype=None, + ): """ - If max_position_embeddings <= 0, there's no position embeddings - If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension - the project up to embed_dim + If max_position_embeddings <= 0, there's no position embeddings + If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension + the project up to embed_dim """ - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if word_embed_proj_dim is None: - self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx, - **factory_kwargs) + self.word_embeddings = nn.Embedding( + vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs + ) self.project_in = None else: - self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim, - padding_idx=padding_idx, **factory_kwargs) - self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False, - **factory_kwargs) + self.word_embeddings = nn.Embedding( + vocab_size, + word_embed_proj_dim, + padding_idx=padding_idx, + **factory_kwargs, + ) + self.project_in = nn.Linear( + word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs + ) self.max_position_embeddings = max_position_embeddings if self.max_position_embeddings > 0: - self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim, - **factory_kwargs) + self.position_embeddings = nn.Embedding( + max_position_embeddings, embed_dim, **factory_kwargs + ) def forward(self, input_ids, position_ids=None): """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) """ batch_size, seqlen = input_ids.shape embeddings = self.word_embeddings(input_ids) @@ -684,44 +820,80 @@ def forward(self, input_ids, position_ids=None): embeddings = self.project_in(embeddings) if self.max_position_embeddings > 0: if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_ids = torch.arange( + seqlen, dtype=torch.long, device=input_ids.device + ) position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings return embeddings + class LMBackbone(nn.Module): - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - process_group=None, layer=None, - attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, embed_dropout: float = 0.1, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + d_model: int, + n_layer: int, + d_inner: int, + vocab_size: int, + process_group=None, + layer=None, + attn_layer_idx=None, + attn_cfg=None, + max_position_embeddings=0, + resid_dropout: float = 0.0, + embed_dropout: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_cfg=None, + residual_in_fp32=False, + device=None, + dtype=None, + **kwargs, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.process_group = process_group self.residual_in_fp32 = residual_in_fp32 # note max_position_embeddings is 0 for Hyena, and therefore isn't used - self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, - **factory_kwargs) - - self.layers = nn.ModuleList([create_block( - d_model, d_inner=d_inner, - layer=layer, attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, - resid_dropout1=embed_dropout if i == 0 else resid_dropout, - resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32,layer_idx=i, - **factory_kwargs, - ) for i in range(n_layer)]) + self.embeddings = GPT2Embeddings( + d_model, vocab_size, max_position_embeddings, **factory_kwargs + ) + + self.layers = nn.ModuleList( + [ + create_block( + d_model, + d_inner=d_inner, + layer=layer, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + layer_norm_epsilon=layer_norm_epsilon, + resid_dropout1=embed_dropout if i == 0 else resid_dropout, + resid_dropout2=resid_dropout, + residual_in_fp32=residual_in_fp32, + layer_idx=i, + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) self.drop_f = nn.Dropout(resid_dropout) self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) - self.apply(partial(_init_weights, n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}))) + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) def forward(self, input_ids, position_ids=None): - hidden_states = self.embeddings(input_ids, position_ids=position_ids,) + hidden_states = self.embeddings( + input_ids, + position_ids=position_ids, + ) residual = None for layer in self.layers: @@ -733,7 +905,8 @@ def forward(self, input_ids, position_ids=None): return hidden_states -#@title Decoder head layer + +# @title Decoder head layer """ A simple decoder head (using MLP) to predict a sequence level classification. @@ -753,7 +926,9 @@ def __init__( ): super().__init__() - self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) + self.output_transform = ( + nn.Identity() if d_output is None else nn.Linear(d_model, d_output) + ) if l_output is None: self.l_output = None @@ -770,7 +945,7 @@ def __init__( self.use_lengths = use_lengths self.mode = mode - if mode == 'ragged': + if mode == "ragged": assert not use_lengths def forward(self, x, state=None, lengths=None, l_output=None): @@ -819,7 +994,7 @@ def restrict(x): elif self.mode == "sum": restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :] # TODO use same restrict function as pool case - elif self.mode == 'ragged': + elif self.mode == "ragged": assert lengths is not None, "lengths must be provided for ragged mode" # remove any additional padding (beyond max length of any sequence in the batch) restrict = lambda x: x[..., : max(lengths), :] @@ -853,7 +1028,8 @@ def step(self, x, state=None): # Ignore all length logic return self.output_transform(x) -#@title Model (backbone + head) + +# @title Model (backbone + head) """ Putting it all together, the model consists of a backbone model @@ -866,43 +1042,77 @@ def step(self, x, state=None): """ + class HyenaDNAModel(nn.Module): - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - layer=None, attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, embed_dropout: float = 0.1, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, - pad_vocab_size_multiple: int = 1, use_head=False, n_classes: int = 2, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} + def __init__( + self, + d_model: int, + n_layer: int, + d_inner: int, + vocab_size: int, + layer=None, + attn_layer_idx=None, + attn_cfg=None, + max_position_embeddings=0, + resid_dropout: float = 0.0, + embed_dropout: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_cfg=None, + residual_in_fp32=False, + pad_vocab_size_multiple: int = 1, + use_head=False, + n_classes: int = 2, + device=None, + dtype=None, + **kwargs, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + vocab_size += pad_vocab_size_multiple - ( + vocab_size % pad_vocab_size_multiple + ) self.use_head = use_head # check if layer (config) has d_model (HF code differs from main Safari code) - if 'd_model' not in layer: - layer['d_model'] = d_model + if "d_model" not in layer: + layer["d_model"] = d_model self.backbone = LMBackbone( - d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, - layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, + d_model=d_model, + n_layer=n_layer, + d_inner=d_inner, + vocab_size=vocab_size, + layer=layer, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, max_position_embeddings=max_position_embeddings, - resid_dropout=resid_dropout, embed_dropout=embed_dropout, + resid_dropout=resid_dropout, + embed_dropout=embed_dropout, layer_norm_epsilon=layer_norm_epsilon, - initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32, - **factory_kwargs, **kwargs + initializer_cfg=initializer_cfg, + residual_in_fp32=residual_in_fp32, + **factory_kwargs, + **kwargs, ) # we only need a head if doing classification, otherwise we'll use the # hidden states as embeddings if self.use_head: - self.head = SequenceDecoder(d_model=d_model, d_output=n_classes, l_output=0, mode='pool') + self.head = SequenceDecoder( + d_model=d_model, d_output=n_classes, l_output=0, mode="pool" + ) # Initialize weights and apply final processing - self.apply(partial(_init_weights, n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}))) + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) # if self.use_head: # self.tie_weights() @@ -910,7 +1120,9 @@ def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, # def tie_weights(self): # self.head.weight = self.backbone.embeddings.word_embeddings.weight - def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface + def forward( + self, input_ids, position_ids=None, state=None + ): # state for the repo interface hidden_states = self.backbone(input_ids, position_ids=position_ids) if self.use_head: @@ -918,12 +1130,13 @@ def forward(self, input_ids, position_ids=None, state=None): # state for the rep else: return hidden_states + """# Data pipeline """ -#@title Tokenizer +# @title Tokenizer """ Just a simple character level tokenizer. @@ -935,9 +1148,14 @@ def forward(self, input_ids, position_ids=None, state=None): # state for the rep """ - class CharacterTokenizer(PreTrainedTokenizer): - def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs): + def __init__( + self, + characters: Sequence[str], + model_max_length: int, + padding_side: str = "left", + **kwargs, + ): """Character tokenizer for Hugging Face transformers. Args: characters (Sequence[str]): List of desired characters. Any character which @@ -965,6 +1183,18 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False) + self._vocab_str_to_int = { + "[CLS]": 0, + "[SEP]": 1, + "[BOS]": 2, + "[MASK]": 3, + "[PAD]": 4, + "[RESERVED]": 5, + "[UNK]": 6, + **{ch: i + 7 for i, ch in enumerate(characters)}, + } + self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + super().__init__( bos_token=bos_token, eos_token=sep_token, @@ -979,17 +1209,8 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid **kwargs, ) - self._vocab_str_to_int = { - "[CLS]": 0, - "[SEP]": 1, - "[BOS]": 2, - "[MASK]": 3, - "[PAD]": 4, - "[RESERVED]": 5, - "[UNK]": 6, - **{ch: i + 7 for i, ch in enumerate(characters)}, - } - self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()} + def get_vocab(self) -> Dict[str, int]: + return self._vocab_str_to_int @property def vocab_size(self) -> int: From 54556387a05372c8990ef3a8d8ec706354fced08 Mon Sep 17 00:00:00 2001 From: Ben Viggiano Date: Sat, 15 Jun 2024 11:22:44 -0700 Subject: [PATCH 2/2] Added reference to -hf versions of checkpoints to primary README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 2770324..100681a 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,9 @@ There's a few way to use these HuggingFace weights, all with different flavors: 2. [Pytorch Lighting in this repo](#loadweights) 3. [standalone](#standalone) +Thanks to friends at HuggingFace, we also have versions of these checkpoints +that can be loaded utilizing the transformers library `AutoModel` and `AutoTokenizer` classes! This makes it super easy to load HyenaDNA models to use in your own codebase. You can access our collection of these checkpoints [here!](https://huggingface.co/collections/LongSafari/hyenadna-models-654d0cbbe113b04ba5a0f638) + ## Dependencies