add neuTRENO from Nguyen et al. from the mitigating oversmoothing in transformers paper
lucidrains committed Oct 27, 2024
1 parent ea689f0 commit eeed503
Showing 3 changed files with 32 additions and 6 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -2319,4 +2319,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@article{Nguyen2023MitigatingOI,
title = {Mitigating Over-smoothing in Transformers via Regularized Nonlocal Functionals},
author = {Tam Nguyen and Tan M. Nguyen and Richard G. Baraniuk},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.00751},
url = {https://api.semanticscholar.org/CorpusID:264300597}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
-version = '1.40.8',
+version = '1.40.9',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
25 changes: 20 additions & 5 deletions x_transformers/x_transformers.py
@@ -944,6 +944,8 @@ def __init__(
cope_talking_heads = False,
softclamp_logits = False,
logit_softclamp_value = 50.,
neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
neutreno_alpha = 0.4,
onnxable = False
):
super().__init__()
@@ -982,6 +984,11 @@ def __init__(

self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

# the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing

self.neutreno_value_residual = neutreno_value_residual
self.neutreno_alpha = neutreno_alpha

# add GLU gating for aggregated values, from alphafold2

self.to_v_gate = None
@@ -1244,11 +1251,15 @@ def forward(
attn_bias = rel_pos(i, j)
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

-# previous values passed in
-# https://arxiv.org/abs/2410.17897v1
+# if previous values passed in for residual, either invoke resformer or neutreno

if exists(value_residual):
-    v = v + value_residual
+    if self.neutreno_value_residual:
+        diff_values = (value_residual - v) * self.neutreno_alpha
+        diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+    else:
+        # https://arxiv.org/abs/2410.17897v1
+        v = v + value_residual

# attention is all we need

@@ -1259,10 +1270,13 @@ def forward(
prev_attn = prev_attn
)

-# store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+# store the values for resformer or Neutreno

intermediates.values = v

+if exists(value_residual) and self.neutreno_value_residual:
+    out = out + diff_values

# https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

if exists(r):
@@ -1365,7 +1379,7 @@ def __init__(
layerscale_init_value = 0.,
unet_skips = False,
reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
+add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
**kwargs
):
super().__init__()
@@ -1378,6 +1392,7 @@ def __init__(
assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)

self.dim = dim
self.causal = causal
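The change above adds a residual of the form alpha * (V_first_layer - V_current_layer) to the attention output whenever the first layer's values are passed back in. Below is a minimal, self-contained sketch of that computation; the function name `neutreno_residual`, the toy shapes, and the `v_first` caching are illustrative assumptions, not x-transformers internals.

```python
# Minimal sketch of the NeuTRENO value residual (Nguyen et al., https://arxiv.org/abs/2312.00751):
# add alpha * (V_first_layer - V_current_layer) to the attention output to counter oversmoothing.
# Names, shapes, and the caching scheme are assumptions for illustration only.

import torch
from einops import repeat

def neutreno_residual(out, v, v_first, alpha = 0.4, heads = 8, kv_heads = 8):
    # out     : attention output of the current layer, shape (b, heads, n, d)
    # v       : current layer's values,                shape (b, kv_heads, n, d)
    # v_first : values cached from the first layer,    shape (b, kv_heads, n, d)
    diff = (v_first - v) * alpha
    # broadcast key/value heads across grouped query heads, mirroring the repeat in the diff
    diff = repeat(diff, 'b h n d -> b (r h) n d', r = heads // kv_heads)
    return out + diff

b, heads, n, d = 2, 8, 16, 64
v_first = torch.randn(b, heads, n, d)  # stored after the first attention layer
v       = torch.randn(b, heads, n, d)  # values of a later layer
out     = torch.randn(b, heads, n, d)  # softmax(QK^T / sqrt(d)) @ V for that later layer

out = neutreno_residual(out, v, v_first)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```

Note that the last hunk also ORs `add_value_residual` with the new flag, so setting `neutreno_value_residual` is enough to have the first layer's values routed to later layers. Assuming the usual `attn_`-prefixed kwargs of x-transformers, enabling it would presumably look like `Decoder(dim = 512, depth = 6, heads = 8, attn_neutreno_value_residual = True, attn_neutreno_alpha = 0.4)`; treat that call as a hedged example rather than documented usage.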
