diff --git a/README.md b/README.md index ccc48746..ea9963be 100644 --- a/README.md +++ b/README.md @@ -882,4 +882,26 @@ model(nodes, context = encoded_neighbors, mask = node_masks, context_mask = neig } ``` +```bibtex +@misc{narang2021transformer, + title = {Do Transformer Modifications Transfer Across Implementations and Applications?}, + author = {Sharan Narang and Hyung Won Chung and Yi Tay and William Fedus and Thibault Fevry and Michael Matena and Karishma Malkan and Noah Fiedel and Noam Shazeer and Zhenzhong Lan and Yanqi Zhou and Wei Li and Nan Ding and Jake Marcus and Adam Roberts and Colin Raffel}, + year = {2021}, + eprint = {2102.11972}, + archivePrefix = {arXiv}, + primaryClass = {cs.LG} +} +``` + +```bibtex +@misc{zhang2019root, + title = {Root Mean Square Layer Normalization}, + author = {Biao Zhang and Rico Sennrich}, + year = {2019}, + eprint = {1910.07467}, + archivePrefix = {arXiv}, + primaryClass = {cs.LG} +} +``` + *solve intelligence... then use that to solve everything else.* - Demis Hassabis diff --git a/setup.py b/setup.py index fe3b80d8..d94f1a41 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'x-transformers', packages = find_packages(exclude=['examples']), - version = '0.8.4', + version = '0.9.0', license='MIT', description = 'X-Transformers - Pytorch', author = 'Phil Wang', diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py index 90a60abc..9190eafa 100644 --- a/x_transformers/x_transformers.py +++ b/x_transformers/x_transformers.py @@ -168,12 +168,24 @@ def forward(self, x, **kwargs): class ScaleNorm(nn.Module): def __init__(self, dim, eps = 1e-5): super().__init__() + self.scale = dim ** -0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) def forward(self, x): - n = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps) - return x / n * self.g + norm = torch.norm(x, dim = -1, keepdim = True) * self.scale + return x / norm.clamp(min = self.eps) * self.g + +class RMSNorm(nn.Module): + def __init__(self, dim, eps = 1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim = -1, keepdim = True) * self.scale + return x / norm.clamp(min = self.eps) * self.g class Residual(nn.Module): def forward(self, x, residual): @@ -382,6 +394,7 @@ def __init__( cross_attend = False, only_cross = False, use_scalenorm = False, + use_rmsnorm = False, use_rezero = False, rel_pos_bias = False, rel_pos_num_buckets = 32, @@ -414,6 +427,7 @@ def __init__( self.cross_residual_attn = cross_residual_attn norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class norm_fn = partial(norm_class, dim) norm_fn = nn.Identity if use_rezero else norm_fn