From eeed50391855f8d552edc866701952f38015e415 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 27 Oct 2024 10:47:56 -0700
Subject: [PATCH] add neuTRENO from Nguyen et al. from the mitigating
 oversmoothing in transformers paper

---
 README.md                        | 11 +++++++++++
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 25 ++++++++++++++++++++-----
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 0662b076..f87e7952 100644
--- a/README.md
+++ b/README.md
@@ -2319,4 +2319,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
 }
 ```
 
+```bibtex
+@article{Nguyen2023MitigatingOI,
+    title   = {Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals},
+    author  = {Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2312.00751},
+    url     = {https://api.semanticscholar.org/CorpusID:264300597}
+}
+```
+
 *solve intelligence... then use that to solve everything else.* - Demis Hassabis
diff --git a/setup.py b/setup.py
index be304539..3e917ed3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.40.8',
+  version = '1.40.9',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 84006973..8cdd6608 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -944,6 +944,8 @@ def __init__(
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
+        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
+        neutreno_alpha = 0.4,
         onnxable = False
     ):
         super().__init__()
@@ -982,6 +984,11 @@ def __init__(
 
         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
 
+        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
+
+        self.neutreno_value_residual = neutreno_value_residual
+        self.neutreno_alpha = neutreno_alpha
+
         # add GLU gating for aggregated values, from alphafold2
 
         self.to_v_gate = None
@@ -1244,11 +1251,15 @@ def forward(
                 attn_bias = rel_pos(i, j)
                 attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
-        # previous values passed in
-        # https://arxiv.org/abs/2410.17897v1
+        # if previous values passed in for residual, either invoke resformer or neutreno
 
         if exists(value_residual):
-            v = v + value_residual
+            if self.neutreno_value_residual:
+                diff_values = (value_residual - v) * self.neutreno_alpha
+                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+            else:
+                # https://arxiv.org/abs/2410.17897v1
+                v = v + value_residual
 
         # attention is all we need
 
@@ -1259,10 +1270,13 @@ def forward(
             prev_attn = prev_attn
         )
 
-        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+        # store the values for resformer or Neutreno
 
         intermediates.values = v
 
+        if exists(value_residual) and self.neutreno_value_residual:
+            out = out + diff_values
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
 
         if exists(r):
@@ -1365,7 +1379,7 @@ def __init__(
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
         **kwargs
     ):
         super().__init__()
@@ -1378,6 +1392,7 @@ def __init__(
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
 
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
 
         self.dim = dim
         self.causal = causal
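Below is a minimal sketch of the idea the patch implements, outside of the library: the first attention layer's value tensor is threaded to deeper layers, and each deeper layer adds `alpha * (first_values - current_values)` to its attention output to counter oversmoothing, mirroring the `diff_values` / `out = out + diff_values` lines above. This is not the x-transformers API; the function name `attend_with_neutreno`, the plain scaled-dot-product attention call, and the assumption that query and key/value head counts match (the patch handles grouped heads with an einops `repeat`) are all illustrative simplifications.

```python
import torch
import torch.nn.functional as F

def attend_with_neutreno(q, k, v, first_layer_values = None, alpha = 0.4):
    # q, k, v: (batch, heads, seq_len, dim_head)
    out = F.scaled_dot_product_attention(q, k, v)

    if first_layer_values is not None:
        # neuTRENO correction: pull the output back toward the first layer's values,
        # analogous to `out = out + diff_values` in the patch above
        out = out + alpha * (first_layer_values - v)

    # also return this layer's values so the caller can thread them to deeper layers
    return out, v

# usage: the first layer's values act as the residual for all deeper layers
q = k = v = torch.randn(2, 8, 16, 64)
out_first, first_values = attend_with_neutreno(q, k, v)
out_deeper, _ = attend_with_neutreno(q, k, v, first_layer_values = first_values)
```

In the patch itself this wiring is enabled per attention layer via `neutreno_value_residual = True` (with strength `neutreno_alpha`), and `AttentionLayers` turns on `add_value_residual` automatically so the first layer's values get passed down the stack.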