ModernBERT FlexAttention #35423

Open
wants to merge 12 commits into main
118 changes: 111 additions & 7 deletions src/transformers/models/modernbert/modeling_modernbert.py
@@ -40,7 +40,7 @@
    is_flash_attn_2_available,
    logging,
)
from ...utils.import_utils import is_triton_available
from ...utils.import_utils import is_torch_flex_attn_available, is_triton_available
from .configuration_modernbert import ModernBertConfig


@@ -51,6 +51,14 @@
else:
    RotaryEmbedding = object

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    flex_attention = torch.compile(flex_attention)
Contributor Author:

These are needed to actually use the FlexAttention kernels, but utils/modular_model_converter.py does not allow this in the converted file.
Let me know if there is a better way to do this.

Contributor Author:

cc @ArthurZucker @tomaarsen (sorry for the ping)

Contributor Author:

any ideas on how to support compiling the flex_attention function in transformers?

Contributor Author:

The failing CI is related to these two lines; there is no clear way to support compile cleanly here.
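
One possible direction, sketched here rather than taken from this PR (the helper name _get_compiled_flex_attention is hypothetical): defer torch.compile to the first call, so the converted file never holds a module-level compiled object.

from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def _get_compiled_flex_attention():
    # Compile lazily on first use instead of at import time.
    from torch.nn.attention.flex_attention import flex_attention

    return torch.compile(flex_attention)


# usage inside the attention path:
# attn_output = _get_compiled_flex_attention()(query, key, value, block_mask=block_mask)

The PR instead compiles at import time, which is what the modular converter and the CI object to.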

    create_block_mask = torch.compile(create_block_mask)
else:
    BlockMask, create_block_mask, flex_attention = object, object, object

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
@@ -450,10 +458,44 @@ def sdpa_attention_forward(
    return (attn_output,)


def flex_attention_forward(
    module: "ModernBertAttention",
    qkv: torch.Tensor,
    rotary_emb: ModernBertUnpaddedRotaryEmbedding,
    position_ids: Optional[torch.LongTensor],
    block_mask: "BlockMask",
    max_seqlen: int,
    bs: int,
    dim: int,
    **_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # add dummy batch dimension -> [batch_size=1, total_nnz, 3, nheads, headdim]
    qkv = qkv.unsqueeze(0)
    cos, sin = rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    # query, key, value: [batch_size, nheads, total_nnz, head_dim]
    query, key = apply_rotary_pos_emb(query, key, cos, sin)

    attn_output = flex_attention(
        query,
        key,
        value,
        score_mod=None,
        block_mask=block_mask,
        enable_gqa=False,
        scale=None,
        return_lse=False,
    )

    attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous()
    return (attn_output.view(bs, dim),)


MODERNBERT_ATTENTION_FUNCTION = {
    "flash_attention_2": flash_attention_forward,
    "eager": eager_attention_forward,
    "sdpa": sdpa_attention_forward,
    "flex_attention": flex_attention_forward,
}
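
For context, a minimal smoke test of the FlexAttention call pattern that flex_attention_forward relies on: a single dummy batch with the tokens of all sequences packed along the sequence axis. This is an illustrative sketch only, assuming torch >= 2.5 with FlexAttention and a CUDA device; the sizes are arbitrary and it is not part of the PR.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"  # assumption: a GPU is available
total_nnz, n_heads, head_dim = 128, 2, 64
q = torch.randn(1, n_heads, total_nnz, head_dim, device=device)
k = torch.randn(1, n_heads, total_nnz, head_dim, device=device)
v = torch.randn(1, n_heads, total_nnz, head_dim, device=device)


def full_mask(b, h, q_idx, kv_idx):
    # allow every (q, kv) pair; the model passes a document/window mask instead
    return q_idx >= 0


block_mask = create_block_mask(full_mask, B=None, H=None, Q_LEN=total_nnz, KV_LEN=total_nnz, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # [1, n_heads, total_nnz, head_dim]

In the model, the block mask is instead built from cu_seqlens by create_flex_attention_mask below, so tokens never attend across sequence boundaries.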


@@ -516,7 +558,7 @@ def forward(
        qkv = self.Wqkv(hidden_states)

        bs = hidden_states.shape[0]
        if self.config._attn_implementation == "flash_attention_2":
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
        else:
            qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
@@ -560,6 +602,7 @@ def forward(
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        block_mask: Optional["BlockMask"] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: Optional[bool] = False,
    ) -> torch.Tensor:
@@ -569,6 +612,7 @@
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            block_mask=block_mask,
            max_seqlen=max_seqlen,
            output_attentions=output_attentions,
        )
@@ -611,7 +655,7 @@ class ModernBertPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = False
    _supports_flex_attn = True

    def _init_weights(self, module: nn.Module):
        cutoff_factor = self.config.initializer_cutoff_factor
@@ -802,6 +846,50 @@ def _pad_modernbert_output(
    return padded_inputs


def offsets_to_sequence_ids_tensor(cu_seqlens):
    """
    Converts cumulative sequence length offsets into sequence IDs.
    """
    counts = cu_seqlens[1:] - cu_seqlens[:-1]
    return torch.repeat_interleave(torch.arange(len(counts), device=cu_seqlens.device, dtype=torch.int32), counts)


def create_flex_attention_mask(cu_seqlens, window_size):
    """
    Creates an attention mask for FlexAttention.

    Args:
        cu_seqlens (torch.Tensor): Cumulative sequence lengths of the packed (unpadded) batch.
        window_size (int): Size of the attention window for local attention.
    """
    sequence_ids = offsets_to_sequence_ids_tensor(cu_seqlens)

    def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
        # only allow attention within the same sequence
        same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]

        # get position within the sequence
        q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
        kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]

        # sliding window within each sequence
        in_window = (q_pos - kv_pos).abs() <= window_size

        return same_seq & in_window

    total_nnz = cu_seqlens[-1]

    block_mask = create_block_mask(
        sliding_window_seq_mask_mod,
        B=None,
        H=None,
        Q_LEN=total_nnz,
        KV_LEN=total_nnz,
        device=sequence_ids.device,
    )
    return block_mask
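
As an illustrative check (not part of the PR), here is how the two helpers above behave for a packed batch with cu_seqlens = [0, 3, 5], i.e. a 3-token sequence followed by a 2-token sequence; the dense re-evaluation of the mask predicate uses an arbitrary window_size of 1.

import torch

cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
sequence_ids = offsets_to_sequence_ids_tensor(cu_seqlens)
print(sequence_ids)  # tensor([0, 0, 0, 1, 1], dtype=torch.int32)

# Re-evaluate the mask predicate densely (outside create_block_mask)
# just to visualise which (q_idx, kv_idx) pairs would be admitted.
window_size = 1
total_nnz = int(cu_seqlens[-1])
seq_ids = sequence_ids.long()
q_idx = torch.arange(total_nnz).unsqueeze(1)   # column of query indices
kv_idx = torch.arange(total_nnz).unsqueeze(0)  # row of key/value indices
same_seq = seq_ids[q_idx] == seq_ids[kv_idx]
q_pos = q_idx - cu_seqlens[seq_ids[q_idx]]
kv_pos = kv_idx - cu_seqlens[seq_ids[kv_idx]]
allowed = same_seq & ((q_pos - kv_pos).abs() <= window_size)
print(allowed.int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)

Within a single sequence the window check reduces to |q_idx - kv_idx| <= window_size, since both positions are offset by the same cu_seqlens entry.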


MODERNBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -935,7 +1023,7 @@ def forward(
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        repad = False
        if self.config._attn_implementation == "flash_attention_2":
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
@@ -954,10 +1042,24 @@
        attention_mask, sliding_window_mask = self._update_attention_mask(
            attention_mask, output_attentions=output_attentions
        )
        if self.config._attn_implementation == "flex_attention":
            position_ids = torch.arange(cu_seqlens[-1], device=device).unsqueeze(0)

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        for encoder_layer in self.layers:
        if self.config._attn_implementation == "flex_attention":
            _cached_local_mask = create_flex_attention_mask(cu_seqlens, window_size=self.config.local_attention // 2)
            _cached_global_mask = create_flex_attention_mask(cu_seqlens, window_size=max_seqlen)
        else:
            block_mask = None

        for layer_id, encoder_layer in enumerate(self.layers):
            if self.config._attn_implementation == "flex_attention":
                if layer_id % self.config.global_attn_every_n_layers == 0:
                    block_mask = _cached_global_mask
                else:
                    block_mask = _cached_local_mask

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

@@ -969,6 +1071,7 @@
                    sliding_window_mask,
                    position_ids,
                    cu_seqlens,
                    block_mask,
                    max_seqlen,
                    output_attentions,
                )
@@ -979,6 +1082,7 @@
                    sliding_window_mask=sliding_window_mask,
                    position_ids=position_ids,
                    cu_seqlens=cu_seqlens,
                    block_mask=block_mask,
                    max_seqlen=max_seqlen,
                    output_attentions=output_attentions,
                )
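
The loop above alternates between the two cached masks on a fixed schedule: every global_attn_every_n_layers-th layer attends globally, the rest use the sliding-window mask. A tiny, hypothetical illustration (the value 3 is an assumed example, not read from this diff):

global_attn_every_n_layers = 3
for layer_id in range(8):
    kind = "global" if layer_id % global_attn_every_n_layers == 0 else "local"
    print(layer_id, kind)
# 0 global, 1 local, 2 local, 3 global, 4 local, 5 local, 6 global, 7 local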
@@ -1110,7 +1214,7 @@ def forward(
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        self._maybe_set_compile()

        if self.config._attn_implementation == "flash_attention_2":
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
@@ -1169,7 +1273,7 @@ def forward(
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

        if self.config._attn_implementation == "flash_attention_2":
        if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]:
            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
