Commit 38879f8
Merge branch 'cherry-pick-25b2723a' into 'core_r0.4.0'
Merge branch 'fix_bert' into 'main'

See merge request ADLR/megatron-lm!1006
jaredcasper committed Dec 11, 2023
2 parents 8ce8065 + f087652 commit 38879f8
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions megatron/core/models/bert/bert_model.py
@@ -14,7 +14,6 @@
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import get_linear_layer
-from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids


class BertModel(LanguageModule):
@@ -126,6 +125,40 @@ def __init__(
        if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
            self.initialize_last_stage_with_word_embeddings()

+    def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
+        """Creates the extended attention mask
+        Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] and makes it binary
+        Args:
+            attention_mask (Tensor): The input attention mask
+        Returns:
+            Tensor: The extended binary attention mask
+        """
+        # We create a 3D attention mask from a 2D tensor mask.
+        # [b, 1, s]
+        attention_mask_b1s = attention_mask.unsqueeze(1)
+        # [b, s, 1]
+        attention_mask_bs1 = attention_mask.unsqueeze(2)
+        # [b, s, s]
+        attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+        # [b, 1, s, s]
+        extended_attention_mask = attention_mask_bss.unsqueeze(1)
+
+        # Convert attention mask to binary:
+        extended_attention_mask = extended_attention_mask < 0.5
+
+        return extended_attention_mask
+
+    def bert_position_ids(self, token_ids):
+        # Create position ids
+        seq_length = token_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
+
+        return position_ids

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.
@@ -158,9 +191,9 @@ def forward(
        It either returns the Loss values if labels are given or the final hidden units
        """
-        extended_attention_mask = bert_extended_attention_mask(attention_mask)
+        extended_attention_mask = self.bert_extended_attention_mask(attention_mask)

-        position_ids = bert_position_ids(input_ids)
+        position_ids = self.bert_position_ids(input_ids)

        # Encoder embedding.
        if self.pre_process:
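For reference, below is a minimal standalone sketch (not part of the commit) of what the two new helper methods compute, run on a toy batch so the shapes are visible. It assumes PyTorch is installed; the wrapper names make_extended_attention_mask and make_position_ids, the toy padding mask, and the token values are illustrative assumptions, not code from the repository.

# Standalone illustration of the new BertModel helpers (assumes PyTorch).
import torch

def make_extended_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # [b, s] padding mask -> [b, 1, s] * [b, s, 1] gives the pairwise [b, s, s] mask;
    # a broadcastable head dimension then yields [b, 1, s, s].
    attention_mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
    extended = attention_mask_bss.unsqueeze(1)
    # Invert to a boolean mask: True marks positions to be masked out.
    return extended < 0.5

def make_position_ids(token_ids: torch.Tensor) -> torch.Tensor:
    # One row of [0 .. seq_len - 1], broadcast to every sequence in the batch.
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
    return position_ids.unsqueeze(0).expand_as(token_ids)

if __name__ == "__main__":
    # Two sequences of length 4; the second has one padded (0) position.
    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
    tokens = torch.randint(0, 30522, (2, 4))
    print(make_extended_attention_mask(mask).shape)  # torch.Size([2, 1, 4, 4])
    print(make_position_ids(tokens))                 # tensor([[0, 1, 2, 3], [0, 1, 2, 3]])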
