add layer norm weight plus 1
Yejing-Lai committed Apr 18, 2024
1 parent bcedecd commit a66ff80
Showing 3 changed files with 45 additions and 4 deletions.
megatron/model/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -6,7 +6,7 @@
     from apex.normalization import MixedFusedRMSNorm as RMSNorm
 else:
     from .rmsnorm import RMSNorm
-    from torch.nn import LayerNorm
+    from .layer_norm_p1 import LayerNorm1P as LayerNorm

 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
megatron/model/layer_norm_p1.py (38 changes: 38 additions & 0 deletions)

@@ -0,0 +1,38 @@
import math
import numbers

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import init


class LayerNorm1P(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
        super(LayerNorm1P, self).__init__()
        self.eps = eps
        self.apply_layernorm_1p = apply_layernorm_1p

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.weight = Parameter(torch.Tensor(*normalized_shape))
        self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        if self.apply_layernorm_1p:
            # Zero-centered weight; the effective scale used in forward() is (weight + 1).
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            # Standard LayerNorm initialization.
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        if self.apply_layernorm_1p:
            weight_plus_1 = (self.weight + 1)
            output = torch.nn.functional.layer_norm(input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
            return output
        else:
            return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
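
As a quick sanity check (a minimal sketch, not part of the commit; it assumes the module is importable as megatron.model.layer_norm_p1): with apply_layernorm_1p=True the weight is initialized to zero, so the effective scale (weight + 1) starts at one and the module matches torch.nn.LayerNorm at initialization.

import torch
from megatron.model.layer_norm_p1 import LayerNorm1P

hidden = 16
x = torch.randn(2, 4, hidden)

ln_1p = LayerNorm1P(hidden, eps=1e-5, apply_layernorm_1p=True)  # weight -> 0, bias -> 0
ln_ref = torch.nn.LayerNorm(hidden, eps=1e-5)                   # weight -> 1, bias -> 0

# (weight + 1) equals 1 at init, so both modules should produce identical outputs.
torch.testing.assert_close(ln_1p(x), ln_ref(x))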
megatron/model/transformer.py (9 changes: 6 additions & 3 deletions)

@@ -913,7 +913,8 @@ def __init__(self, config,
             else:
                 self.input_layernorm = LayerNorm(
                     config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                    eps=config.layernorm_epsilon,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
         else:
             self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Self attention.
@@ -939,7 +940,8 @@ def __init__(self, config,
             else:
                 self.post_attention_layernorm = LayerNorm(
                     config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                    eps=config.layernorm_epsilon,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
         else:
             self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Cross attention.
@@ -1762,7 +1764,8 @@ def build_layer(layer_number, n_e):
                 else:
                     self.final_layernorm = LayerNorm(
                         config.hidden_size,
-                        eps=config.layernorm_epsilon)
+                        eps=config.layernorm_epsilon,
+                        apply_layernorm_1p=args.apply_layernorm_1p)
             else:
                 self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
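
For orientation (a hedged sketch, not part of the diff): each of the three call sites simply forwards the existing args.apply_layernorm_1p flag into the constructor, and with the flag off LayerNorm1P falls back to standard LayerNorm behavior, so existing configurations are unaffected. The SimpleNamespace objects below are illustrative stand-ins for Megatron's real config and args.

from types import SimpleNamespace
import torch
from megatron.model.layer_norm_p1 import LayerNorm1P as LayerNorm

# Illustrative stand-ins only; the real code receives Megatron's config and args objects.
config = SimpleNamespace(hidden_size=1024, layernorm_epsilon=1e-5)
args = SimpleNamespace(apply_layernorm_1p=True)

input_layernorm = LayerNorm(
    config.hidden_size,
    eps=config.layernorm_epsilon,
    apply_layernorm_1p=args.apply_layernorm_1p)

x = torch.randn(4, 2, config.hidden_size)
print(input_layernorm(x).shape)  # torch.Size([4, 2, 1024])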
