
Commit 68ce7e7

add rms norm, given transformer modifications paper out of google
lucidrains committed Mar 21, 2021
1 parent c40815f commit 68ce7e7
Showing 3 changed files with 39 additions and 3 deletions.
22 changes: 22 additions & 0 deletions README.md
@@ -882,4 +882,26 @@ model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neig
}
```

```bibtex
@misc{narang2021transformer,
title = {Do Transformer Modifications Transfer Across Implementations and Applications?},
author = {Sharan Narang and Hyung Won Chung and Yi Tay and William Fedus and Thibault Fevry and Michael Matena and Karishma Malkan and Noah Fiedel and Noam Shazeer and Zhenzhong Lan and Yanqi Zhou and Wei Li and Nan Ding and Jake Marcus and Adam Roberts and Colin Raffel},
year = {2021},
eprint = {2102.11972},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```

```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.8.4',
version = '0.9.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
18 changes: 16 additions & 2 deletions x_transformers/x_transformers.py
@@ -168,12 +168,24 @@ def forward(self, x, **kwargs):
class ScaleNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))

def forward(self, x):
n = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps)
return x / n * self.g
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g

class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))

def forward(self, x):
norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
return x / norm.clamp(min = self.eps) * self.g

class Residual(nn.Module):
def forward(self, x, residual):
@@ -382,6 +394,7 @@ def __init__(
cross_attend = False,
only_cross = False,
use_scalenorm = False,
use_rmsnorm = False,
use_rezero = False,
rel_pos_bias = False,
rel_pos_num_buckets = 32,
@@ -414,6 +427,7 @@ def __init__(
self.cross_residual_attn = cross_residual_attn

norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)

norm_fn = nn.Identity if use_rezero else norm_fn
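
For context, a minimal usage sketch of the new flag (not part of this commit): it assumes `use_rmsnorm` is forwarded through `Decoder` to the attention layers the same way `use_scalenorm` already is, and that the rest of the `TransformerWrapper` API is unchanged.

```python
# Assumed usage sketch, mirroring the existing use_scalenorm example.
# RMSNorm divides each token vector by sqrt(mean(x ** 2)) -- which is what
# torch.norm(x, dim = -1) * dim ** -0.5 computes -- then rescales with a
# learned per-dimension gain g.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_rmsnorm = True   # swaps nn.LayerNorm for the RMSNorm class added above
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)
```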
