From 5895994c3e382b7a407fd9f4b9d7e2c7fe200150 Mon Sep 17 00:00:00 2001
From: "Lai, Yejing"
Date: Wed, 17 Apr 2024 22:00:33 -0700
Subject: [PATCH] add layer norm weight plus 1

---
 megatron/model/__init__.py      |  1 +
 megatron/model/layer_norm_p1.py | 14 ++++++++++++++
 megatron/model/transformer.py   | 33 ++++++++++++++++++++++++---------
 3 files changed, 39 insertions(+), 9 deletions(-)
 create mode 100644 megatron/model/layer_norm_p1.py

diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 2306749fcb..626c75fb89 100644
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -7,6 +7,7 @@
 else:
     from .rmsnorm import RMSNorm
     from torch.nn import LayerNorm
+    from .layer_norm_p1 import LayerNorm1P
 
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
diff --git a/megatron/model/layer_norm_p1.py b/megatron/model/layer_norm_p1.py
new file mode 100644
index 0000000000..9287f17223
--- /dev/null
+++ b/megatron/model/layer_norm_p1.py
@@ -0,0 +1,14 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+class LayerNorm1P(torch.nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super(LayerNorm1P, self).__init__(*args, **kwargs)
+
+    def forward(self, input):
+        weight_plus_1 = (self.weight + 1)
+        output = torch.nn.functional.layer_norm(input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
+        return output
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e75f13a24f..7d2f8fc981 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -911,9 +911,14 @@ def __init__(self, config,
                     apply_layernorm_1p=args.apply_layernorm_1p,
                     mem_efficient_ln=args.mem_efficient_ln)
             else:
-                self.input_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                if args.apply_layernorm_1p:
+                    self.input_layernorm = LayerNorm1P(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
+                else:
+                    self.input_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
         else:
             self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Self attention.
@@ -937,9 +942,14 @@ def __init__(self, config,
                     apply_layernorm_1p=args.apply_layernorm_1p,
                     mem_efficient_ln=args.mem_efficient_ln)
             else:
-                self.post_attention_layernorm = LayerNorm(
-                    config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                if args.apply_layernorm_1p:
+                    self.post_attention_layernorm = LayerNorm1P(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
+                else:
+                    self.post_attention_layernorm = LayerNorm(
+                        config.hidden_size,
+                        eps=config.layernorm_epsilon)
         else:
             self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Cross attention.
@@ -1760,9 +1770,14 @@ def build_layer(layer_number, n_e):
                         apply_layernorm_1p=args.apply_layernorm_1p,
                         mem_efficient_ln=args.mem_efficient_ln)
                 else:
-                    self.final_layernorm = LayerNorm(
-                        config.hidden_size,
-                        eps=config.layernorm_epsilon)
+                    if args.apply_layernorm_1p:
+                        self.final_layernorm = LayerNorm1P(
+                            config.hidden_size,
+                            eps=config.layernorm_epsilon)
+                    else:
+                        self.final_layernorm = LayerNorm(
+                            config.hidden_size,
+                            eps=config.layernorm_epsilon)
             else:
                 self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
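
The snippet below is not part of the patch; it is a minimal standalone sketch of the weight-plus-1 parameterization that LayerNorm1P implements. The class body mirrors megatron/model/layer_norm_p1.py above so it runs without a Megatron checkout; the zero initialization of the weight is an assumption made for the demonstration (as I understand it, that is the convention args.apply_layernorm_1p relies on, so the stored parameter is zero-centered and the effective scale starts at 1).

# Standalone sketch, not part of the patch above.
import torch


class LayerNorm1P(torch.nn.LayerNorm):
    def forward(self, input):
        # Effective scale is (stored weight + 1), so a zero-initialized weight
        # behaves like a standard LayerNorm with unit weight.
        return torch.nn.functional.layer_norm(
            input, self.normalized_shape, self.weight + 1, self.bias, self.eps)


hidden_size = 16
x = torch.randn(2, 4, hidden_size)

ln_1p = LayerNorm1P(hidden_size)
torch.nn.init.zeros_(ln_1p.weight)        # assumed 1p init: store weight - 1, start at 0
torch.nn.init.zeros_(ln_1p.bias)

ln_ref = torch.nn.LayerNorm(hidden_size)  # default init: weight = 1, bias = 0

assert torch.allclose(ln_1p(x), ln_ref(x), atol=1e-6)

Storing the scale as an offset from 1 keeps the learnable parameter zero-centered, which is the usual motivation for the layernorm-1p variant.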