From 4eacdd79a25a7868b383aa32e2b7ac44693832cc Mon Sep 17 00:00:00 2001
From: Simon Lang
Date: Thu, 20 Jun 2024 08:31:50 +0100
Subject: [PATCH] temporary fix for chunking (#6)

Co-authored-by: Jesper Dramsch
---
 src/anemoi/models/layers/block.py | 23 +----------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/src/anemoi/models/layers/block.py b/src/anemoi/models/layers/block.py
index f8cbb79..ba29607 100644
--- a/src/anemoi/models/layers/block.py
+++ b/src/anemoi/models/layers/block.py
@@ -490,28 +490,7 @@ def forward(
         ), "Only batch size of 1 is supported when model is sharded across GPUs"
 
         query, key, value, edges = self.shard_qkve_heads(query, key, value, edges, shapes, batch_size, model_comm_group)
-
-        # TODO: remove magic number
-        num_chunks = self.num_chunks if self.training else 4  # reduce memory for inference
-
-        if num_chunks > 1:
-            edge_index_list = torch.tensor_split(edge_index, num_chunks, dim=1)
-            edge_attr_list = torch.tensor_split(edges, num_chunks, dim=0)
-            for i in range(num_chunks):
-                out1 = self.conv(
-                    query=query,
-                    key=key,
-                    value=value,
-                    edge_attr=edge_attr_list[i],
-                    edge_index=edge_index_list[i],
-                    size=size,
-                )
-                if i == 0:
-                    out = torch.zeros_like(out1)
-                out = out + out1
-        else:
-            out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
-
+        out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)
         out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)
 
         out = self.projection(out + x_r)