From 845bfc7158c12e157a920115d20f42867c3b2b09 Mon Sep 17 00:00:00 2001
From: Cathal OBrien
Date: Fri, 18 Oct 2024 14:26:25 +0000
Subject: [PATCH] window can be optional

---
 src/anemoi/models/layers/attention.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/src/anemoi/models/layers/attention.py b/src/anemoi/models/layers/attention.py
index 502c236..b5f78df 100644
--- a/src/anemoi/models/layers/attention.py
+++ b/src/anemoi/models/layers/attention.py
@@ -49,11 +49,11 @@ def __init__(
         self,
         num_heads: int,
         embed_dim: int,
-        resolution: str,
         bias: bool = False,
         is_causal: bool = False,
         window_size: Optional[int] = None,
         dropout_p: float = 0.0,
+        resolution: str = 'X0',
     ):
         super().__init__()

@@ -77,16 +77,19 @@ def __init__(
         LOGGER.info("Using Flex attn")
         #LOGGER.info(f"self.num_heads {self.num_heads} self.embed_dim {self.embed_dim} self.head_dim {self.head_dim} self.dropout {self.dropout_p}")

-        def sliding_window(b, h, q_idx, kv_idx):
-            return abs(q_idx - kv_idx) <= window_size
-
-        seq_len=calculate_seq_len(resolution=self.resolution)
-        LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
-
-        # B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
-        #TODO check if B != 1, does it have to be set?
-        self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,_compile=True)
-        self.attention = functools.partial(flex_attention, block_mask=self.block_mask) #Cache the block mask (attn blog post)
+        if window_size is not None:
+            def sliding_window(b, h, q_idx, kv_idx):
+                return abs(q_idx - kv_idx) <= window_size
+
+            seq_len=calculate_seq_len(resolution=self.resolution)
+            LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")
+
+            # B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
+            #TODO check if B != 1, does it have to be set?
+            self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,_compile=True)
+            self.attention = functools.partial(flex_attention, block_mask=self.block_mask) #Cache the block mask (attn blog post)
+        else:
+            self.attention = flex_attention

         self.attention = compile(self.attention) #Must be compiled, otherwise entire seq_len^2 aray is materilised in memory -> OOM

         if (self.is_causal):
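
Note: the sliding-window branch above uses PyTorch's flex attention API (PyTorch 2.5+). The snippet below is a minimal standalone sketch of that pattern, not the anemoi module itself; WINDOW_SIZE, SEQ_LEN, NUM_HEADS, HEAD_DIM, the tensor shapes and the CUDA device are assumptions for illustration, and SEQ_LEN stands in for calculate_seq_len(resolution=...).

    import functools

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    WINDOW_SIZE = 512            # assumed value, for illustration only
    SEQ_LEN = 4096               # stands in for calculate_seq_len(resolution=...)
    NUM_HEADS, HEAD_DIM = 8, 64  # assumed head configuration


    def sliding_window(b, h, q_idx, kv_idx):
        # Keep only key/value positions within +/- WINDOW_SIZE of the query index
        return abs(q_idx - kv_idx) <= WINDOW_SIZE


    # B=None and H=None let the mask broadcast over batch and head dimensions
    block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN)

    # Cache the block mask in the attention callable and compile it, mirroring the patched __init__
    attention = functools.partial(flex_attention, block_mask=block_mask)
    attention = torch.compile(attention)

    q = torch.randn(1, NUM_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
    k = torch.randn(1, NUM_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
    v = torch.randn(1, NUM_HEADS, SEQ_LEN, HEAD_DIM, device="cuda")
    out = attention(q, k, v)  # shape (1, NUM_HEADS, SEQ_LEN, HEAD_DIM)

When window_size is None, the patch instead assigns plain flex_attention with no block mask, i.e. full global attention over the sequence, before compiling it.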