Commit
temporary fix for chunking (#6)
Co-authored-by: Jesper Dramsch <[email protected]>
ssmmnn11 and JesperDramsch authored Jun 20, 2024
1 parent 5628c90 commit 4eacdd7
Showing 1 changed file with 1 addition and 22 deletions.
23 changes: 1 addition & 22 deletions src/anemoi/models/layers/block.py
@@ -490,28 +490,7 @@ def forward(
         ), "Only batch size of 1 is supported when model is sharded across GPUs"

         query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
-
-        # TODO: remove magic number
-        num_chunks = self.num_chunks if self.training else 4  # reduce memory for inference
-
-        if num_chunks > 1:
-            edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
-            edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
-            for i in range(num_chunks):
-                out1 = self.conv(
-                    query=query,
-                    key=key,
-                    value=value,
-                    edge_attr=edge_attr_list[i],
-                    edge_index=edge_index_list[i],
-                    size=size,
-                )
-                if i == 0:
-                    out = torch.zeros_like(out1)
-                out = out + out1
-        else:
-            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
-
+        out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
         out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
         out = self.projection(out + x_r)
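For context on what the removed branch did: at inference time it split the graph edges into chunks and accumulated the partial message-passing outputs, trading extra compute for lower peak memory. The sketch below illustrates that pattern in isolation; it is a minimal, hypothetical example rather than the library's API: process_chunk stands in for self.conv, and the shapes ([2, num_edges] for edge_index, [num_edges, dim] for edge_attr) are assumptions.

import torch


def chunked_edge_aggregation(
    edge_index: torch.Tensor,  # assumed shape [2, num_edges]
    edge_attr: torch.Tensor,   # assumed shape [num_edges, dim]
    num_nodes: int,
    num_chunks: int = 4,
) -> torch.Tensor:
    """Accumulate message passing over edge chunks, mirroring the removed branch."""

    def process_chunk(idx: torch.Tensor, attr: torch.Tensor) -> torch.Tensor:
        # Stand-in for self.conv(...): scatter-add each edge's features onto its target node.
        out = torch.zeros(num_nodes, attr.shape[-1], dtype=attr.dtype)
        return out.index_add_(0, idx[1], attr)

    # Split indices along the edge dimension (dim=1) and attributes along dim=0
    # so each chunk's indices stay aligned with its attributes.
    index_chunks = torch.tensor_split(edge_index, num_chunks, dim=1)
    attr_chunks = torch.tensor_split(edge_attr, num_chunks, dim=0)

    out = None
    for idx, attr in zip(index_chunks, attr_chunks):
        part = process_chunk(idx, attr)
        out = part if out is None else out + part  # additive aggregation lets partial results sum up
    return out


# Tiny usage example: 3 nodes, 5 edges, 2-dimensional edge features.
edge_index = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 2, 0, 1]])
edge_attr = torch.randn(5, 2)
full = chunked_edge_aggregation(edge_index, edge_attr, num_nodes=3, num_chunks=1)
chunked = chunked_edge_aggregation(edge_index, edge_attr, num_nodes=3, num_chunks=4)
assert torch.allclose(full, chunked)

Because the aggregation over edges is additive, summing the partial outputs over disjoint edge chunks reproduces the unchunked result while only materializing one chunk's intermediate tensors at a time; this commit removes that path, along with its hard-coded fallback of 4 chunks, as a temporary fix.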
