window can be optional
cathalobrien committed Oct 18, 2024
1 parent 61eadb0 commit 845bfc7
1 changed file: src/anemoi/models/layers/attention.py (14 additions, 11 deletions)
@@ -49,11 +49,11 @@ def __init__(
         self,
         num_heads: int,
         embed_dim: int,
-        resolution: str,
         bias: bool = False,
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        resolution: str = 'X0',
     ):
         super().__init__()
 
@@ -77,16 +77,19 @@ def __init__(
             LOGGER.info("Using Flex attn")
             # LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")
 
-            def sliding_window(b, h, q_idx, kv_idx):
-                return abs(q_idx - kv_idx) <= window_size
-
-            seq_len = calculate_seq_len(resolution=self.resolution)
-            LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
-
-            # B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
-            # TODO: check if B != 1, does it have to be set?
-            self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
-            self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # Cache the block mask (attn blog post)
+            if window_size is not None:
+                def sliding_window(b, h, q_idx, kv_idx):
+                    return abs(q_idx - kv_idx) <= window_size
+
+                seq_len = calculate_seq_len(resolution=self.resolution)
+                LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
+
+                # B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
+                # TODO: check if B != 1, does it have to be set?
+                self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, _compile=True)
+                self.attention = functools.partial(flex_attention, block_mask=self.block_mask)  # Cache the block mask (attn blog post)
+            else:
+                self.attention = flex_attention
             self.attention = compile(self.attention)  # Must be compiled, otherwise the entire seq_len^2 array is materialised in memory -> OOM
 
             if (self.is_causal):
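For context, below is a minimal, standalone sketch of the behaviour this commit enables: building either a sliding-window-masked or an unmasked flex attention callable depending on whether window_size is set. It assumes PyTorch >= 2.5 for torch.nn.attention.flex_attention (which primarily targets CUDA devices); the build_attention helper, the sequence length, and the tensor shapes are illustrative placeholders and are not part of the anemoi code, which derives seq_len from calculate_seq_len(resolution=...).

```python
# Hypothetical sketch, not the anemoi implementation. Assumes PyTorch >= 2.5.
import functools
from typing import Optional

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Compile once; the compiled kernel keeps the mask block-sparse instead of
# materialising a full seq_len^2 attention mask in memory.
compiled_flex_attention = torch.compile(flex_attention)


def build_attention(window_size: Optional[int], seq_len: int, device: str = "cpu"):
    """Return a flex-attention callable, sliding-window masked only if window_size is set."""
    if window_size is not None:

        def sliding_window(b, h, q_idx, kv_idx):
            # Attend only to key/value positions within +/- window_size of the query.
            return abs(q_idx - kv_idx) <= window_size

        # B=None / H=None: the mask is identical for every batch and head,
        # so it is broadcast rather than built per batch/head.
        block_mask = create_block_mask(
            sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
        )
        return functools.partial(compiled_flex_attention, block_mask=block_mask)

    # No window given: fall back to dense (full) attention.
    return compiled_flex_attention


# Illustrative usage: batch=1, 4 heads, 1024 "grid points", head_dim=64.
# Flex attention is best supported on GPU; CPU support depends on the torch version.
device = "cuda" if torch.cuda.is_available() else "cpu"
q = k = v = torch.randn(1, 4, 1024, 64, device=device)
attn = build_attention(window_size=128, seq_len=1024, device=device)
out = attn(q, k, v)  # same shape as q
```

In the commit itself, seq_len comes from calculate_seq_len(resolution=self.resolution) and the resulting callable is stored on self.attention, but the branching is the same: a block mask is created only when a window size is provided.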
