diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c7072e3..7f79dc0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you! ### Added +#### Miscellaneous + +- Introduction of remapper to anemoi-models leads to changes in the data indices and some preprocessors cannot be applied in-place anymore. + #### Functionality - Enable the callback for plotting a histogram for variables containing NaNs diff --git a/src/anemoi/training/config/data/zarr.yaml b/src/anemoi/training/config/data/zarr.yaml index 27c17edb..1657861f 100644 --- a/src/anemoi/training/config/data/zarr.yaml +++ b/src/anemoi/training/config/data/zarr.yaml @@ -26,6 +26,7 @@ forcing: diagnostic: - tp - cp +remapped: normalizer: default: "mean-std" @@ -48,17 +49,23 @@ normalizer: imputer: default: "none" +remapper: + default: "none" # processors including imputers and normalizers are applied in order of definition processors: # example_imputer: - # _target_: anemoi.models.preprocessing.imputer.InputImputer - # _convert_: all - # config: ${data.imputer} + # _target_: anemoi.models.preprocessing.imputer.InputImputer + # _convert_: all + # config: ${data.imputer} normalizer: _target_: anemoi.models.preprocessing.normalizer.InputNormalizer _convert_: all config: ${data.normalizer} + # remapper: + # _target_: anemoi.models.preprocessing.remapper.Remapper + # _convert_: all + # config: ${data.remapper} # Values set in the code num_features: null # number of features in the forecast state diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 290bdfa6..72983f24 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -196,15 +196,15 @@ def _eval( batch: torch.Tensor, ) -> None: loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) - # NB! the batch is already normalized in-place - see pl_model.validation_step() metrics = {} # start rollout + batch = pl_module.model.pre_processors(batch, in_place=False) x = batch[ :, 0 : pl_module.multi_step, ..., - pl_module.data_indices.data.input.full, + pl_module.data_indices.internal_data.input.full, ] # (bs, multi_step, latlon, nvar) assert ( batch.shape[1] >= self.rollout + pl_module.multi_step @@ -217,7 +217,7 @@ def _eval( :, pl_module.multi_step + rollout_step, ..., - pl_module.data_indices.data.output.full, + pl_module.data_indices.internal_data.output.full, ] # target, shape = (bs, latlon, nvar) # y includes the auxiliary variables, so we must leave those out when computing the loss loss += pl_module.loss(y_pred, y) @@ -408,10 +408,10 @@ def automatically_determine_group(name: str) -> str: parameters_to_groups = unique_group_list[group_inverse] unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) - # sort paramters by groups + # sort parameters by groups sort_by_parameter_group = np.argsort(group_inverse, kind="stable") - # apply new order to paramters + # apply new order to parameters sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] parameters_to_groups = parameters_to_groups[sort_by_parameter_group] unique_group_list, group_inverse, group_counts = np.unique( @@ -460,17 +460,19 @@ def _plot( batch_idx: int, epoch: int, ) -> None: - del batch_idx # unused logger = trainer.logger - parameter_names = list(pl_module.data_indices.model.output.name_to_index.keys()) - paramter_positions = list(pl_module.data_indices.model.output.name_to_index.values()) + parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys()) + parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) # reorder parameter_names by position - self.parameter_names = [parameter_names[i] for i in np.argsort(paramter_positions)] + self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] + batch = pl_module.model.pre_processors(batch, in_place=False) for rollout_step in range(pl_module.rollout): y_hat = outputs[1][rollout_step] - y_true = batch[:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.data.output.full] + y_true = batch[ + :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full + ] loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group @@ -542,11 +544,12 @@ def _plot( self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) local_rank = pl_module.local_rank + batch = pl_module.model.pre_processors(batch, in_place=False) input_tensor = batch[ self.sample_idx, pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, ..., - pl_module.data_indices.data.output.full, + pl_module.data_indices.internal_data.output.full, ].cpu() data = self.post_processors(input_tensor).numpy() @@ -635,12 +638,12 @@ def _plot( if self.latlons is None: self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) local_rank = pl_module.local_rank - + batch = pl_module.model.pre_processors(batch, in_place=False) input_tensor = batch[ self.sample_idx, pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, ..., - pl_module.data_indices.data.output.full, + pl_module.data_indices.internal_data.output.full, ].cpu() data = self.post_processors(input_tensor).numpy() output_tensor = self.post_processors( diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 8140da4b..b2004cf4 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -479,42 +479,90 @@ def plot_flat_sample( norm=TwoSlopeNorm(vcenter=0.0), title=f"{vname} pred err", ) - else: - scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target") - scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred") + elif vname == "mwd": + cyclic_colormap = "twilight" + + def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray: + """Calculate error between two arrays in degrees in range [-180, 180].""" + tmp = (array1 - array2) % 360 + return np.where(tmp > 180, tmp - 360, tmp) + + sample_shape = truth.shape + pred = np.maximum(np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred))) + scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=cyclic_colormap, title=f"{vname} target") + scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=cyclic_colormap, title=f"capped {vname} pred") + err_plot = error_plot_in_degrees(truth, pred) scatter_plot( fig, ax[3], lon=lon, lat=lat, - data=truth - pred, - cmap="bwr", - norm=TwoSlopeNorm(vcenter=0.0), - title=f"{vname} pred err", - ) - - if sum(input_) != 0: - scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input") - scatter_plot( - fig, - ax[4], - lon=lon, - lat=lat, - data=pred - input_, + data=err_plot, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), - title=f"{vname} increment [pred - input]", + title=f"{vname} pred err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", ) + else: + scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target") + scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred") scatter_plot( fig, - ax[5], + ax[3], lon=lon, lat=lat, - data=truth - input_, + data=truth - pred, cmap="bwr", norm=TwoSlopeNorm(vcenter=0.0), - title=f"{vname} persist err", + title=f"{vname} pred err", ) + + if sum(input_) != 0: + if vname == "mwd": + scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, cmap=cyclic_colormap, title=f"{vname} input") + err_plot = error_plot_in_degrees(pred, input_) + scatter_plot( + fig, + ax[4], + lon=lon, + lat=lat, + data=err_plot, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} increment [pred - input] % 360", + ) + err_plot = error_plot_in_degrees(truth, input_) + scatter_plot( + fig, + ax[5], + lon=lon, + lat=lat, + data=err_plot, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.", + ) + else: + scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input") + scatter_plot( + fig, + ax[4], + lon=lon, + lat=lat, + data=pred - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} increment [pred - input]", + ) + scatter_plot( + fig, + ax[5], + lon=lon, + lat=lat, + data=truth - input_, + cmap="bwr", + norm=TwoSlopeNorm(vcenter=0.0), + title=f"{vname} persist err", + ) else: ax[0].axis("off") ax[4].axis("off") diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index ff3fb916..ff1acfd7 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -84,7 +84,10 @@ def __init__( self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled - self.metric_ranges, loss_scaling = self.metrics_loss_scaling(config, data_indices) + self.metric_ranges, self.metric_ranges_validation, loss_scaling = self.metrics_loss_scaling( + config, + data_indices, + ) self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling) self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True) @@ -127,8 +130,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @staticmethod def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, torch.Tensor]: metric_ranges = defaultdict(list) + metric_ranges_validation = defaultdict(list) loss_scaling = ( - np.ones((len(data_indices.data.output.full),), dtype=np.float32) * config.training.loss_scaling.default + np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32) + * config.training.loss_scaling.default ) pressure_level = instantiate(config.training.pressure_level_scaler) @@ -140,15 +145,17 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t pressure_level.minimum, ) - for key, idx in data_indices.model.output.name_to_index.items(): + for key, idx in data_indices.internal_model.output.name_to_index.items(): # Split pressure levels on "_" separator split = key.split("_") - if len(split) > 1: + if len(split) > 1 and split[-1].isdigit(): # Create grouped metrics for pressure levels (e.g. Q, T, U, V, etc.) for logger metric_ranges[f"pl_{split[0]}"].append(idx) # Create pressure levels in loss scaling vector if split[0] in config.training.loss_scaling.pl: - loss_scaling[idx] = config.training.loss_scaling.pl[split[0]] * pressure_level.scaler(int(split[1])) + loss_scaling[idx] = config.training.loss_scaling.pl[split[0]] * pressure_level.scaler( + int(split[-1]), + ) else: LOGGER.debug("Parameter %s was not scaled.", key) else: @@ -162,7 +169,19 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t if key in config.training.metrics: metric_ranges[key] = [idx] loss_scaling = torch.from_numpy(loss_scaling) - return metric_ranges, loss_scaling + # metric for validation, after postprocessing + for key, idx in data_indices.model.output.name_to_index.items(): + # Split pressure levels on "_" separator + split = key.split("_") + if len(split) > 1 and split[1].isdigit(): + # Create grouped metrics for pressure levels (e.g. Q, T, U, V, etc.) for logger + metric_ranges_validation[f"pl_{split[0]}"].append(idx) + else: + metric_ranges_validation[f"sfc_{key}"].append(idx) + # Create specific metrics from hydra to log in logger + if key in config.training.metrics: + metric_ranges_validation[key] = [idx] + return metric_ranges, metric_ranges_validation, loss_scaling def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: LOGGER.debug("set_model_comm_group: %s", model_comm_group) @@ -178,18 +197,18 @@ def advance_input( x = x.roll(-1, dims=1) # Get prognostic variables - x[:, -1, :, :, self.data_indices.model.input.prognostic] = y_pred[ + x[:, -1, :, :, self.data_indices.internal_model.input.prognostic] = y_pred[ ..., - self.data_indices.model.output.prognostic, + self.data_indices.internal_model.output.prognostic, ] # get new "constants" needed for time-varying fields - x[:, -1, :, :, self.data_indices.model.input.forcing] = batch[ + x[:, -1, :, :, self.data_indices.internal_model.input.forcing] = batch[ :, self.multi_step + rollout_step, :, :, - self.data_indices.data.input.forcing, + self.data_indices.internal_data.input.forcing, ] return x @@ -201,18 +220,24 @@ def _step( ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) - batch = self.model.pre_processors(batch) # normalized in-place + # for validation not normalized in-place because remappers cannot be applied in-place + batch = self.model.pre_processors(batch, in_place=not validation_mode) metrics = {} - # start rollout - x = batch[:, 0 : self.multi_step, ..., self.data_indices.data.input.full] # (bs, multi_step, latlon, nvar) + # start rollout of preprocessed batch + x = batch[ + :, + 0 : self.multi_step, + ..., + self.data_indices.internal_data.input.full, + ] # (bs, multi_step, latlon, nvar) y_preds = [] for rollout_step in range(self.rollout): # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) y_pred = self(x) - y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full] + y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss loss += checkpoint(self.loss, y_pred, y, use_reentrant=False) @@ -243,7 +268,7 @@ def calculate_val_metrics( y_preds = [] y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) - for mkey, indices in self.metric_ranges.items(): + for mkey, indices in self.metric_ranges_validation.items(): metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], diff --git a/tests/train/test_loss_scaling.py b/tests/train/test_loss_scaling.py index 51764d64..0e6a60b3 100644 --- a/tests/train/test_loss_scaling.py +++ b/tests/train/test_loss_scaling.py @@ -21,6 +21,11 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: "data": { "forcing": ["x"], "diagnostic": ["z", "q"], + "remapped": [ + { + "d": ["cos_d", "sin_d"], + }, + ], }, "training": { "loss_scaling": { @@ -36,7 +41,7 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: }, }, ) - name_to_index = {"x": 0, "y_50": 1, "y_500": 2, "y_850": 3, "z": 5, "q": 4, "other": 6} + name_to_index = {"x": 0, "y_50": 1, "y_500": 2, "y_850": 3, "z": 5, "q": 4, "other": 6, "d": 7} data_indices = IndexCollection(config=config, name_to_index=name_to_index) return config, data_indices @@ -70,6 +75,8 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: 1, # q 0.1, # z 100, # other + 1, # cos_d + 1, # sin_d ], ) expected_relu_scaling = torch.Tensor( @@ -80,6 +87,8 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: 1, # q 0.1, # z 100, # other + 1, # cos_d + 1, # sin_d ], ) expected_constant_scaling = torch.Tensor( @@ -90,6 +99,8 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: 1, # q 0.1, # z 100, # other + 1, # cos_d + 1, # sin_d ], ) expected_polynomial_scaling = torch.Tensor( @@ -100,6 +111,8 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: 1, # q 0.1, # z 100, # other + 1, # cos_d + 1, # sin_d ], ) @@ -117,7 +130,7 @@ def fake_data(request: SubRequest) -> tuple[DictConfig, IndexCollection]: def test_loss_scaling_vals(fake_data: tuple[DictConfig, IndexCollection], expected_scaling: torch.Tensor) -> None: config, data_indices = fake_data - _, loss_scaling = GraphForecaster.metrics_loss_scaling(config, data_indices) + _, _, loss_scaling = GraphForecaster.metrics_loss_scaling(config, data_indices) assert torch.allclose(loss_scaling, expected_scaling) @@ -126,9 +139,9 @@ def test_loss_scaling_vals(fake_data: tuple[DictConfig, IndexCollection], expect def test_metric_range(fake_data: tuple[DictConfig, IndexCollection]) -> None: config, data_indices = fake_data - metric_range, _ = GraphForecaster.metrics_loss_scaling(config, data_indices) + metric_range, metric_ranges_validation, _ = GraphForecaster.metrics_loss_scaling(config, data_indices) - expected_metric_range = { + expected_metric_range_validation = { "pl_y": [ data_indices.model.output.name_to_index["y_50"], data_indices.model.output.name_to_index["y_500"], @@ -138,7 +151,14 @@ def test_metric_range(fake_data: tuple[DictConfig, IndexCollection]) -> None: "sfc_q": [data_indices.model.output.name_to_index["q"]], "sfc_z": [data_indices.model.output.name_to_index["z"]], "other": [data_indices.model.output.name_to_index["other"]], + "sfc_d": [data_indices.model.output.name_to_index["d"]], "y_850": [data_indices.model.output.name_to_index["y_850"]], } + expected_metric_range = expected_metric_range_validation.copy() + del expected_metric_range["sfc_d"] + expected_metric_range["sfc_cos_d"] = [data_indices.internal_model.output.name_to_index["cos_d"]] + expected_metric_range["sfc_sin_d"] = [data_indices.internal_model.output.name_to_index["sin_d"]] + + assert dict(metric_ranges_validation) == expected_metric_range_validation assert dict(metric_range) == expected_metric_range