Skip to content

Commit

Permalink
calculates seq_len based on processor hidden resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
cathalobrien committed Oct 18, 2024
1 parent 4230fb8 commit 61eadb0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/anemoi/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.models.layers.utils import calculate_seq_len

LOGGER = logging.getLogger(__name__)

Expand All @@ -48,6 +49,7 @@ def __init__(
self,
num_heads: int,
embed_dim: int,
resolution: str,
bias: bool = False,
is_causal: bool = False,
window_size: Optional[int] = None,
Expand All @@ -68,6 +70,8 @@ def __init__(

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

self.resolution = resolution

if _FLEX_ATTENTION_AVAILABLE and (os.environ.get("FLEX_ATTN", "") != "" ):
LOGGER.info("Using Flex attn")
Expand All @@ -76,19 +80,9 @@ def __init__(
# Flex-attention mask predicate: a query/key pair may attend iff the two
# token positions are within `window_size` of each other (banded / sliding
# window attention). `b` and `h` (batch, head) are unused because the mask
# is uniform across those dims; `window_size` is captured from the
# enclosing scope.
def sliding_window(b, h, q_idx, kv_idx):
    return abs(q_idx - kv_idx) <= window_size

#TODO stop hardcoding latent space dims
seq_len=40320 #o96
#seq_len=5248 #o32

def calculate_seq_len(grid_type: str, num_lat_lines: int):
accum=0
for i in range(1,num_lat_lines):
accum += (4 * i) + 16
return accum * 2

LOGGER.info(f"{calculate_seq_len('o', 32)}")


seq_len=calculate_seq_len(resolution=self.resolution)
LOGGER.debug(f"grid points = {seq_len} for {self.resolution} resolution")

# B and H can be None here because they are uniform, so the block mask can just be broadcast to these dims
#TODO check if B != 1, does it have to be set?
self.block_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len,_compile=True)
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
num_heads: int,
activation: str,
window_size: int,
resolution: str,
dropout_p: float = 0.0,
):
super().__init__()
Expand All @@ -82,6 +83,7 @@ def __init__(
self.attention = MultiHeadSelfAttention(
num_heads=num_heads,
embed_dim=num_channels,
resolution=resolution,
window_size=window_size,
bias=False,
is_causal=False,
Expand Down
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
num_channels: int,
num_layers: int,
window_size: int,
resolution: str,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
num_heads=num_heads,
activation=activation,
window_size=window_size,
resolution=resolution,
dropout_p=dropout_p,
)

Expand Down
3 changes: 3 additions & 0 deletions src/anemoi/models/layers/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
resolution: str = "X0",
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(
cpu_offload=cpu_offload,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
resolution=resolution,
)

self.build_layers(
Expand All @@ -137,6 +139,7 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
resolution=resolution,
)

self.offload_layers(cpu_offload)
Expand Down
15 changes: 15 additions & 0 deletions src/anemoi/models/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,18 @@ def forward(self, x: Tensor) -> Tensor:
precision.
"""
return super().forward(x).type_as(x)

# Takes a resolution string and calculates the number of grid points,
# e.g. "o32" -> 5248, "o96" -> 40320.
def calculate_seq_len(resolution: str) -> int:
    """Return the total number of grid points for a given grid resolution.

    Parameters
    ----------
    resolution : str
        Grid identifier: a grid-type letter followed by the number of
        latitude lines between a pole and the equator, e.g. "o32", "o96".
        Only the octahedral reduced Gaussian grid ('o'/'O') is supported.

    Returns
    -------
    int
        Number of grid points over the whole globe.

    Raises
    ------
    ValueError
        If the grid type is not 'o' (octahedral reduced Gaussian grid).
    """
    grid_type = resolution[0]  # e.g. 'o'
    num_lat_lines = int(resolution[1:])  # e.g. 32
    if grid_type.lower() != 'o':
        # Bug fix: the original constructed ValueError(...) without `raise`,
        # so unsupported grids fell through and crashed with an
        # UnboundLocalError on `result` instead of a meaningful error.
        raise ValueError("Only octahedral (reduced) Gaussian grid, 'o', implemented")
    # Algorithm from
    # https://confluence.ecmwf.int/display/FCST/Introducing+the+octahedral+reduced+Gaussian+grid
    # Latitude line i (counting from the pole) holds 4*i + 16 points.
    accum = sum((4 * i) + 16 for i in range(1, num_lat_lines + 1))
    # Above counts pole -> equator only; double for the whole globe.
    return accum * 2

0 comments on commit 61eadb0

Please sign in to comment.