Skip to content

Commit

Permalink
fix for transformer conv inference and doc string
Browse files Browse the repository at this point in the history
  • Loading branch information
ssmmnn11 committed Jun 8, 2024
1 parent f6427a3 commit 5faeafa
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 25 deletions.
23 changes: 1 addition & 22 deletions src/anemoi/models/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,28 +600,7 @@ def forward(
), "Only batch size of 1 is supported when model is sharded across GPUs"

query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)

# TODO: Is this alright?
num_chunks = self.num_chunks if self.training else 4 # reduce memory for inference

if num_chunks > 1:
edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
for i in range(num_chunks):
out1 = self.conv(
query=query,
key=key,
value=value,
edge_attr=edge_attr_list[i],
edge_index=edge_index_list[i],
size=size,
)
if i == 0:
out = torch.zeros_like(out1)
out = out + out1
else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
out = self.projection(out + x_r)

Expand Down
7 changes: 4 additions & 3 deletions src/anemoi/models/layers/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# nor does it submit to any jurisdiction.
#

import math
from typing import Optional

import torch
Expand Down Expand Up @@ -77,7 +76,9 @@ def aggregate(self, edges_new: Tensor, edge_index: Adj, dim_size: Optional[int]


class GraphTransformerConv(MessagePassing):
"""Message passing part of graph transformer operator."""
"""Message passing part of graph transformer operator from 'Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification'
(https://arxiv.org/abs/2009.03509)
"""

def __init__(
self,
Expand Down Expand Up @@ -130,7 +131,7 @@ def message(
if edge_attr is not None:
key_j = key_j + edge_attr

alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
alpha = (query_i * key_j).sum(dim=-1) / self.out_channels ** (1.0 / 2.0)

alpha = softmax(alpha, index, ptr, size_i)
alpha = dropout(alpha, p=self.dropout, training=self.training)
Expand Down

0 comments on commit 5faeafa

Please sign in to comment.