diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 2fa5a08acc4..9c1eed50658 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -40,7 +40,7 @@ is_flash_attn_2_available, logging, ) -from ...utils.import_utils import is_triton_available +from ...utils.import_utils import is_torch_flex_attn_available, is_triton_available from .configuration_modernbert import ModernBertConfig @@ -51,6 +51,14 @@ else: RotaryEmbedding = object +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + + flex_attention = torch.compile(flex_attention) + create_block_mask = torch.compile(create_block_mask) +else: + BlockMask, create_block_mask, flex_attention = object, object, object + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" @@ -450,10 +458,44 @@ def sdpa_attention_forward( return (attn_output,) +def flex_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + position_ids: Optional[torch.LongTensor], + block_mask: "BlockMask", + max_seqlen: int, + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # add dummy batch dimension -> [batch_size=1, total_nnz, 3, nheads, headdim] + qkv = qkv.unsqueeze(0) + cos, sin = rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, nheads, total_nnz, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + attn_output = flex_attention( + query, + key, + value, + score_mod=None, + block_mask=block_mask, + enable_gqa=False, + scale=None, + return_lse=False, + ) + + attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() + return (attn_output.view(bs, dim),) + + MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, + "flex_attention": flex_attention_forward, } @@ -516,7 +558,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -560,6 +602,7 @@ def forward( sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -569,6 +612,7 @@ def forward( sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -611,7 +655,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor @@ -802,6 +846,50 @@ def _pad_modernbert_output( return padded_inputs +def offsets_to_sequence_ids_tensor(cu_seqlens): + """ + Converts 
cumulative sequence length offsets into sequence IDs. + """ + counts = cu_seqlens[1:] - cu_seqlens[:-1] + return torch.repeat_interleave(torch.arange(len(counts), device=cu_seqlens.device, dtype=torch.int32), counts) + + +def create_flex_attention_mask(cu_seqlens, window_size): + """ + Creates an attention mask for FlexAttention. + + Args: + cu_seqlens (torch.Tensor): Cumulative sequence lengths + window_size (int, optional): Size of attention window for local attention + """ + sequence_ids = offsets_to_sequence_ids_tensor(cu_seqlens) + + def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): + # only allow attention within the same sequence + same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] + + # get position within the sequence + q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] + kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] + + # sliding window within each sequence + in_window = (q_pos - kv_pos).abs() <= window_size + + return same_seq & in_window + + total_nnz = cu_seqlens[-1] + + block_mask = create_block_mask( + sliding_window_seq_mask_mod, + B=None, + H=None, + Q_LEN=total_nnz, + KV_LEN=total_nnz, + device=sequence_ids.device, + ) + return block_mask + + MODERNBERT_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -935,7 +1023,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: @@ -954,10 +1042,24 @@ def forward( attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) + if self.config._attn_implementation == "flex_attention": + position_ids = torch.arange(cu_seqlens[-1], device=device).unsqueeze(0) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) - for encoder_layer in self.layers: + if self.config._attn_implementation == "flex_attention": + _cached_local_mask = create_flex_attention_mask(cu_seqlens, window_size=self.config.local_attention // 2) + _cached_global_mask = create_flex_attention_mask(cu_seqlens, window_size=max_seqlen) + else: + block_mask = None + + for layer_id, encoder_layer in enumerate(self.layers): + if self.config._attn_implementation == "flex_attention": + if layer_id % self.config.global_attn_every_n_layers == 0: + block_mask = _cached_global_mask + else: + block_mask = _cached_local_mask + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -969,6 +1071,7 @@ def forward( sliding_window_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, ) @@ -979,6 +1082,7 @@ def forward( sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -1110,7 +1214,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: @@ -1169,7 +1273,7 @@ def forward( if labels is not None: loss = 
self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index edfdc94346b..0913767eeb4 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -40,7 +40,7 @@ is_flash_attn_2_available, logging, ) -from ...utils.import_utils import is_triton_available +from ...utils.import_utils import is_torch_flex_attn_available, is_triton_available from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb @@ -51,6 +51,11 @@ else: RotaryEmbedding = object +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +else: + BlockMask, create_block_mask, flex_attention = object, object, object + _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base" _CONFIG_FOR_DOC = "ModernBertConfig" @@ -625,10 +630,88 @@ def sdpa_attention_forward( return (attn_output,) +def offsets_to_sequence_ids_tensor(cu_seqlens): + """ + Converts cumulative sequence length offsets into sequence IDs. + """ + counts = cu_seqlens[1:] - cu_seqlens[:-1] + return torch.repeat_interleave(torch.arange(len(counts), device=cu_seqlens.device, dtype=torch.int32), counts) + + +def create_flex_attention_mask(cu_seqlens, window_size): + """ + Creates an attention mask for FlexAttention. 
+ + Args: + cu_seqlens (torch.Tensor): Cumulative sequence lengths + window_size (int, optional): Size of attention window for local attention + """ + sequence_ids = offsets_to_sequence_ids_tensor(cu_seqlens) + + def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx): + # only allow attention within the same sequence + same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx] + + # get position within the sequence + q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]] + kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]] + + # sliding window within each sequence + in_window = (q_pos - kv_pos).abs() <= window_size + + return same_seq & in_window + + total_nnz = cu_seqlens[-1] + + block_mask = create_block_mask( + sliding_window_seq_mask_mod, + B=None, + H=None, + Q_LEN=total_nnz, + KV_LEN=total_nnz, + device=sequence_ids.device, + ) + return block_mask + + +def flex_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + rotary_emb: ModernBertUnpaddedRotaryEmbedding, + position_ids: Optional[torch.LongTensor], + block_mask: "BlockMask", + max_seqlen: int, + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # add dummy batch dimension -> [batch_size=1, total_nnz, 3, nheads, headdim] + qkv = qkv.unsqueeze(0) + cos, sin = rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, nheads, total_nnz, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + attn_output = flex_attention( + query, + key, + value, + score_mod=None, + block_mask=block_mask, + enable_gqa=False, + scale=None, + return_lse=False, + ) + + attn_output = attn_output.squeeze(0).transpose(0, 1).contiguous() + return (attn_output.view(bs, dim),) + + MODERNBERT_ATTENTION_FUNCTION = { "flash_attention_2": flash_attention_forward, "eager": eager_attention_forward, "sdpa": sdpa_attention_forward, + "flex_attention": flex_attention_forward, } @@ -691,7 +774,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -735,6 +818,7 @@ def forward( sliding_window_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, cu_seqlens: Optional[torch.Tensor] = None, + block_mask: Optional["BlockMask"] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, ) -> torch.Tensor: @@ -744,6 +828,7 @@ def forward( sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -786,7 +871,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = False + _supports_flex_attn = True def _init_weights(self, module: nn.Module): cutoff_factor = self.config.initializer_cutoff_factor @@ -1038,7 +1123,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True 
if inputs_embeds is None: @@ -1057,10 +1142,24 @@ def forward( attention_mask, sliding_window_mask = self._update_attention_mask( attention_mask, output_attentions=output_attentions ) + if self.config._attn_implementation == "flex_attention": + position_ids = torch.arange(cu_seqlens[-1], device=device).unsqueeze(0) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) - for encoder_layer in self.layers: + if self.config._attn_implementation == "flex_attention": + _cached_local_mask = create_flex_attention_mask(cu_seqlens, window_size=self.config.local_attention // 2) + _cached_global_mask = create_flex_attention_mask(cu_seqlens, window_size=max_seqlen) + else: + block_mask = None + + for layer_id, encoder_layer in enumerate(self.layers): + if self.config._attn_implementation == "flex_attention": + if layer_id % self.config.global_attn_every_n_layers == 0: + block_mask = _cached_global_mask + else: + block_mask = _cached_local_mask + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -1072,6 +1171,7 @@ def forward( sliding_window_mask, position_ids, cu_seqlens, + block_mask, max_seqlen, output_attentions, ) @@ -1082,6 +1182,7 @@ def forward( sliding_window_mask=sliding_window_mask, position_ids=position_ids, cu_seqlens=cu_seqlens, + block_mask=block_mask, max_seqlen=max_seqlen, output_attentions=output_attentions, ) @@ -1213,7 +1314,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: @@ -1272,7 +1373,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation == "flash_attention_2": + if self.config._attn_implementation in ["flash_attention_2", "flex_attention"]: with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
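
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new backend would be selected end to end. It assumes a transformers install that includes this diff, a PyTorch build where is_torch_flex_attn_available() returns True, and a CUDA device; the checkpoint id comes from _CHECKPOINT_FOR_DOC above, and the "flex_attention" string matches the key added to MODERNBERT_ATTENTION_FUNCTION. Treat it as a sketch under those assumptions, not as documented API for this PR.

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    # Checkpoint referenced by _CHECKPOINT_FOR_DOC in the diff; any ModernBERT
    # checkpoint should work the same way.
    model_id = "answerdotai/ModernBERT-base"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(
        model_id,
        # Routes attention through flex_attention_forward added above
        # (enabled by _supports_flex_attn = True).
        attn_implementation="flex_attention",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Like the flash_attention_2 path, inputs are unpadded internally and the
    # logits are re-padded by _pad_modernbert_output before being returned.
    inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits = model(**inputs).logits

    mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    print(tokenizer.decode(logits[0, mask_index].argmax(dim=-1)))

Because the BlockMask pair (local sliding-window and global) is rebuilt per forward call from cu_seqlens, the first call pays the torch.compile warm-up for flex_attention and create_block_mask; subsequent calls with the same packed shape should reuse the compiled kernels.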