From 53a23d4d61632c67cde1ed1f2f83579202e9a6f9 Mon Sep 17 00:00:00 2001 From: icedoom888 Date: Thu, 17 Oct 2024 13:08:07 +0200 Subject: [PATCH] Re-added asserts in mapper --- src/anemoi/models/layers/mapper.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/anemoi/models/layers/mapper.py b/src/anemoi/models/layers/mapper.py index 72859cf..dc5c1f0 100644 --- a/src/anemoi/models/layers/mapper.py +++ b/src/anemoi/models/layers/mapper.py @@ -115,7 +115,6 @@ def pre_process(self, x, shard_shapes, model_comm_group=None): return x_src, x_dst, shapes_src, shapes_dst - class GraphEdgeMixin: def _register_edges( self, sub_graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int @@ -135,8 +134,8 @@ def _register_edges( trainable_size : int Trainable tensor size """ - if edge_attributes is None: - raise ValueError("Edge attributes must be provided") + assert sub_graph, f"{self.__class__.__name__} needs a valid sub_graph to register edges." + assert edge_attributes is not None, "Edge attributes must be provided" edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], axis=1)