Skip to content

Commit

Permalink
Re-added asserts in mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Oct 17, 2024
1 parent 4aebd8f commit 53a23d4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/anemoi/models/layers/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def pre_process(self, x, shard_shapes, model_comm_group=None):
return x_src, x_dst, shapes_src, shapes_dst



class GraphEdgeMixin:
def _register_edges(
self, sub_graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int
Expand All @@ -135,8 +134,8 @@ def _register_edges(
trainable_size : int
Trainable tensor size
"""
if edge_attributes is None:
raise ValueError("Edge attributes must be provided")
assert sub_graph, f"{self.__class__.__name__} needs a valid sub_graph to register edges."
assert edge_attributes is not None, "Edge attributes must be provided"

edge_attr_tensor = torch.cat([sub_graph[attr] for attr in edge_attributes], axis=1)

Expand Down

0 comments on commit 53a23d4

Please sign in to comment.