Test dimensions completed.
icedoom888 committed Oct 28, 2024
1 parent 689e3fe commit 32a0ac8
Showing 1 changed file with 32 additions and 4 deletions.
36 changes: 32 additions & 4 deletions src/anemoi/models/models/hierarchical.py
@@ -79,6 +79,11 @@ def __init__(

input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size

print(
input_dim,
getattr(self, f"latlons_{self._graph_hidden_names[0]}").shape[1],
self.hidden_dims[self._graph_hidden_names[0]],
)
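# For reference, a hypothetical sizing example (numbers are illustrative, not from
# this commit): with multi_step = 2, num_input_channels = 90, 4 lat/lon features
# and trainable_data_size = 8, input_dim = 2 * 90 + 4 + 8 = 192.
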
# Encoder data -> hidden
self.encoder = instantiate(
model_config.model.encoder,
@@ -148,7 +153,7 @@ def __init__(
model_config.model.decoder,
in_channels_src=self.hidden_dims[src_nodes_name],
in_channels_dst=self.hidden_dims[dst_nodes_name],
hidden_dim=self.hidden_dims[dst_nodes_name],
hidden_dim=self.hidden_dims[src_nodes_name],
out_channels_dst=self.hidden_dims[dst_nodes_name],
sub_graph=self._graph_data[(src_nodes_name, "to", dst_nodes_name)],
src_grid_size=self._hidden_grid_sizes[src_nodes_name],
@@ -159,14 +164,22 @@
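# The final decoder maps hidden[0] -> data; in_channels_dst is set to input_dim
# below, presumably so the decoder's data-node input matches the width of
# x_data_latent assembled in forward().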
self.decoder = instantiate(
model_config.model.decoder,
in_channels_src=self.hidden_dims[self._graph_hidden_names[0]],
in_channels_dst=self.hidden_dims[self._graph_hidden_names[0]],
in_channels_dst=input_dim,
hidden_dim=self.hidden_dims[self._graph_hidden_names[0]],
out_channels_dst=self.num_output_channels,
sub_graph=self._graph_data[(self._graph_hidden_names[0], "to", self._graph_name_data)],
src_grid_size=self._hidden_grid_sizes[self._graph_hidden_names[0]],
dst_grid_size=self._data_grid_size,
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
self.boundings = nn.ModuleList(
[
instantiate(cfg, name_to_index=self.data_indices.internal_model.output.name_to_index)
for cfg in getattr(model_config.model, "bounding", [])
]
)
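# For context: `instantiate` is Hydra's factory utility, so each entry under
# model_config.model.bounding is expected to be a config node with a `_target_`
# class path plus keyword arguments; extra kwargs such as name_to_index are
# forwarded to that class's constructor. An illustrative (not library-verified)
# entry:
#   bounding:
#     - _target_: <path.to.BoundingClass>  # e.g. a ReLU-style bounding for tp
#       variables: [tp]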

def _define_tensor_sizes(self, config: DotDict) -> None:

# Grid sizes
@@ -186,7 +199,7 @@ def _create_trainable_attributes(self) -> None:

for hidden in self._graph_hidden_names:
self.trainable_hidden[hidden] = TrainableTensor(
trainable_size=self.hidden_dims[hidden], tensor_size=self._hidden_grid_sizes[hidden]
trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_sizes[hidden]
)
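# trainable_size is the number of extra learnable channels appended per hidden
# node (analogous to trainable_data_size in input_dim above), hence the configured
# trainable_hidden_size rather than the full hidden dimension.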

def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
@@ -215,6 +228,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
for hidden, x_latent in x_trainable_hiddens.items():
shard_shapes_hiddens[hidden] = get_shape_shards(x_latent, 0, model_comm_group)

# print('Input: ', x_trainable_data.shape, x_trainable_hiddens[self._graph_hidden_names[0]].shape, shard_shapes_data, shard_shapes_hiddens[self._graph_hidden_names[0]])

# Run encoder
x_data_latent, curr_latent = self._run_mapper(
self.encoder,
Expand All @@ -224,6 +239,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print('After encoding: ', x_data_latent.shape, curr_latent.shape)

# Run processor
x_encoded_latents = {}
x_skip = {}
@@ -241,6 +258,7 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
shard_shapes=shard_shapes_hiddens[src_hidden_name],
model_comm_group=model_comm_group,
)
# print(f'After level of {src_hidden_name}: ', curr_latent.shape)

# store latents for skip connections
x_skip[src_hidden_name] = curr_latent
@@ -254,6 +272,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print(f'After downscaling of {src_hidden_name}: ', curr_latent.shape)

# Processing hidden-most level
if self.level_process:
curr_latent = self.down_level_processor[dst_hidden_name](
@@ -263,6 +283,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print(f'After level of {dst_hidden_name}: ', curr_latent.shape)

## Upscale
for i in range(self.num_hidden - 1, 0, -1):
src_hidden_name = self._graph_hidden_names[i]
@@ -277,8 +299,10 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print(f'After upscaling of {src_hidden_name}: ', curr_latent.shape)

# Add skip connections
curr_latent += x_skip[dst_hidden_name]
curr_latent = curr_latent + x_skip[dst_hidden_name]
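# Out-of-place addition (rather than +=) avoids mutating a tensor in place that
# autograd may still need for the backward pass.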

# Processing at same level
if self.level_process:
@@ -289,6 +313,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print(f'After level of {dst_hidden_name}: ', curr_latent.shape)

# Run decoder
x_out = self._run_mapper(
self.decoder,
@@ -298,6 +324,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
model_comm_group=model_comm_group,
)

# print('After decoding: ', x_out.shape)

x_out = (
einops.rearrange(
x_out,