Changes to data indices in anemoi models (#17)
* feat: incorporate changes to data indices into anemoi-training

* config file update and changelog

* tests: include remapped in data config

* pre-commit

* pre-commit

* test: remapped variable when testing loss scaling and validation metrics

* corrections to plots

* config file

* typos

* fix: fix from develop

* fix: remove len calculation in dataloaders

* config: change entity back to ??? as in develop

* comments: incorporate changes requested from Jesper

* typo in parameter
sahahner authored Sep 9, 2024
1 parent 92a4b06 commit 1a377de
Showing 6 changed files with 163 additions and 56 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you!

### Added

#### Miscellaneous

- Introduction of remapper to anemoi-models leads to changes in the data indices and some preprocessors cannot be applied in-place anymore.

#### Functionality

- Enable the callback for plotting a histogram for variables containing NaNs
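A note on the in-place restriction mentioned in the changelog entry above: a remapper can change the number of variable columns, so its output cannot reuse the input tensor's storage. The sketch below is purely illustrative and not the anemoi-models implementation; `ToyCosSinRemapper` is a hypothetical stand-in that splits one angle variable into cosine/sine components.

```python
import torch


class ToyCosSinRemapper:
    """Hypothetical remapper: splits an angle variable (degrees) into cos/sin components."""

    def __init__(self, angle_index: int) -> None:
        self.angle_index = angle_index

    def __call__(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        if in_place:
            # The remapped tensor has one column more than the input, so it cannot
            # share the input's storage - hence preprocessors like this one can no
            # longer be applied in-place.
            msg = "Remapping changes the number of variables; in_place=True is not supported."
            raise ValueError(msg)
        angle = torch.deg2rad(x[..., self.angle_index])
        keep = [i for i in range(x.shape[-1]) if i != self.angle_index]
        return torch.cat(
            [x[..., keep], torch.cos(angle).unsqueeze(-1), torch.sin(angle).unsqueeze(-1)],
            dim=-1,
        )


batch = torch.rand(4, 3) * 360.0  # (grid points, variables); last column is an angle
remapped = ToyCosSinRemapper(angle_index=2)(batch, in_place=False)
print(batch.shape, remapped.shape)  # torch.Size([4, 3]) torch.Size([4, 4])
```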
13 changes: 10 additions & 3 deletions src/anemoi/training/config/data/zarr.yaml
@@ -26,6 +26,7 @@ forcing:
diagnostic:
- tp
- cp
remapped:

normalizer:
default: "mean-std"
@@ -48,17 +49,23 @@ normalizer:

imputer:
default: "none"
remapper:
default: "none"

# processors including imputers and normalizers are applied in order of definition
processors:
# example_imputer:
# _target_: anemoi.models.preprocessing.imputer.InputImputer
# _convert_: all
# config: ${data.imputer}
# _target_: anemoi.models.preprocessing.imputer.InputImputer
# _convert_: all
# config: ${data.imputer}
normalizer:
_target_: anemoi.models.preprocessing.normalizer.InputNormalizer
_convert_: all
config: ${data.normalizer}
# remapper:
# _target_: anemoi.models.preprocessing.remapper.Remapper
# _convert_: all
# config: ${data.remapper}

# Values set in the code
num_features: null # number of features in the forecast state
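For context on the config block above: each entry under `processors:` is built by Hydra from its `_target_`, and `config: ${data.remapper}` is an OmegaConf interpolation into the `remapper:` block higher up in the file. A minimal sketch of that mechanism, using a hypothetical `MyRemapper` stand-in whose constructor may differ from the real anemoi-models class:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf


class MyRemapper:  # hypothetical stand-in; the real Remapper's signature may differ
    def __init__(self, config: dict) -> None:
        self.config = config


yaml_snippet = """
data:
  remapper:
    default: "none"
  processors:
    remapper:
      _target_: __main__.MyRemapper
      _convert_: all
      config: ${data.remapper}
"""

cfg = OmegaConf.create(yaml_snippet)
# Processors are instantiated in their order of definition, as the comment in zarr.yaml notes.
processors = {name: instantiate(entry) for name, entry in cfg.data.processors.items()}
print(processors["remapper"].config)  # {'default': 'none'}
```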
29 changes: 16 additions & 13 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -196,15 +196,15 @@ def _eval(
batch: torch.Tensor,
) -> None:
loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False)
# NB! the batch is already normalized in-place - see pl_model.validation_step()
metrics = {}

# start rollout
batch = pl_module.model.pre_processors(batch, in_place=False)
x = batch[
:,
0 : pl_module.multi_step,
...,
pl_module.data_indices.data.input.full,
pl_module.data_indices.internal_data.input.full,
] # (bs, multi_step, latlon, nvar)
assert (
batch.shape[1] >= self.rollout + pl_module.multi_step
@@ -217,7 +217,7 @@ def _eval(
:,
pl_module.multi_step + rollout_step,
...,
pl_module.data_indices.data.output.full,
pl_module.data_indices.internal_data.output.full,
] # target, shape = (bs, latlon, nvar)
# y includes the auxiliary variables, so we must leave those out when computing the loss
loss += pl_module.loss(y_pred, y)
@@ -408,10 +408,10 @@ def automatically_determine_group(name: str) -> str:
parameters_to_groups = unique_group_list[group_inverse]
unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True)

# sort paramters by groups
# sort parameters by groups
sort_by_parameter_group = np.argsort(group_inverse, kind="stable")

# apply new order to paramters
# apply new order to parameters
sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group]
parameters_to_groups = parameters_to_groups[sort_by_parameter_group]
unique_group_list, group_inverse, group_counts = np.unique(
@@ -460,17 +460,19 @@ def _plot(
batch_idx: int,
epoch: int,
) -> None:
del batch_idx # unused
logger = trainer.logger

parameter_names = list(pl_module.data_indices.model.output.name_to_index.keys())
paramter_positions = list(pl_module.data_indices.model.output.name_to_index.values())
parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys())
parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
# reorder parameter_names by position
self.parameter_names = [parameter_names[i] for i in np.argsort(paramter_positions)]
self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]

batch = pl_module.model.pre_processors(batch, in_place=False)
for rollout_step in range(pl_module.rollout):
y_hat = outputs[1][rollout_step]
y_true = batch[:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.data.output.full]
y_true = batch[
:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full
]
loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()

sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
@@ -542,11 +544,12 @@ def _plot(
self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
local_rank = pl_module.local_rank

batch = pl_module.model.pre_processors(batch, in_place=False)
input_tensor = batch[
self.sample_idx,
pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
...,
pl_module.data_indices.data.output.full,
pl_module.data_indices.internal_data.output.full,
].cpu()
data = self.post_processors(input_tensor).numpy()

@@ -635,12 +638,12 @@ def _plot(
if self.latlons is None:
self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy())
local_rank = pl_module.local_rank

batch = pl_module.model.pre_processors(batch, in_place=False)
input_tensor = batch[
self.sample_idx,
pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1,
...,
pl_module.data_indices.data.output.full,
pl_module.data_indices.internal_data.output.full,
].cpu()
data = self.post_processors(input_tensor).numpy()
output_tensor = self.post_processors(
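The recurring change in this file is the same two-step pattern: preprocess a copy of the raw batch out-of-place (a remapper may change the variable dimension) and then slice it with the internal data indices rather than the raw ones. A condensed sketch of that pattern, with `pl_module` standing in for the Lightning module used by the callbacks above:

```python
import torch


def select_eval_inputs(pl_module, batch: torch.Tensor) -> torch.Tensor:
    """Mirror of the slicing done in the callbacks above (sketch, not library code)."""
    # Out-of-place: the preprocessed batch may have a different number of columns
    # than the raw batch, so it must not overwrite it.
    batch = pl_module.model.pre_processors(batch, in_place=False)
    return batch[
        :,
        0 : pl_module.multi_step,
        ...,
        pl_module.data_indices.internal_data.input.full,
    ]  # (bs, multi_step, latlon, nvar)
```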
90 changes: 69 additions & 21 deletions src/anemoi/training/diagnostics/plots.py
@@ -479,42 +479,90 @@ def plot_flat_sample(
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} pred err",
)
else:
scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target")
scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred")
elif vname == "mwd":
cyclic_colormap = "twilight"

def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
"""Calculate error between two arrays in degrees in range [-180, 180]."""
tmp = (array1 - array2) % 360
return np.where(tmp > 180, tmp - 360, tmp)

sample_shape = truth.shape
pred = np.maximum(np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred)))
scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=cyclic_colormap, title=f"{vname} target")
scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=cyclic_colormap, title=f"capped {vname} pred")
err_plot = error_plot_in_degrees(truth, pred)
scatter_plot(
fig,
ax[3],
lon=lon,
lat=lat,
data=truth - pred,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} pred err",
)

if sum(input_) != 0:
scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input")
scatter_plot(
fig,
ax[4],
lon=lon,
lat=lat,
data=pred - input_,
data=err_plot,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} increment [pred - input]",
title=f"{vname} pred err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
)
else:
scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target")
scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred")
scatter_plot(
fig,
ax[5],
ax[3],
lon=lon,
lat=lat,
data=truth - input_,
data=truth - pred,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} persist err",
title=f"{vname} pred err",
)

if sum(input_) != 0:
if vname == "mwd":
scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, cmap=cyclic_colormap, title=f"{vname} input")
err_plot = error_plot_in_degrees(pred, input_)
scatter_plot(
fig,
ax[4],
lon=lon,
lat=lat,
data=err_plot,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} increment [pred - input] % 360",
)
err_plot = error_plot_in_degrees(truth, input_)
scatter_plot(
fig,
ax[5],
lon=lon,
lat=lat,
data=err_plot,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
)
else:
scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input")
scatter_plot(
fig,
ax[4],
lon=lon,
lat=lat,
data=pred - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} increment [pred - input]",
)
scatter_plot(
fig,
ax[5],
lon=lon,
lat=lat,
data=truth - input_,
cmap="bwr",
norm=TwoSlopeNorm(vcenter=0.0),
title=f"{vname} persist err",
)
else:
ax[0].axis("off")
ax[4].axis("off")
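The `error_plot_in_degrees` helper introduced above wraps the difference of two angle fields into [-180, 180], so that e.g. a truth of 350° against a prediction of 10° is reported as a -20° error rather than 340°. A small standalone check of that behaviour (the helper is copied here verbatim for illustration):

```python
import numpy as np


def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
    """Calculate error between two arrays in degrees in range [-180, 180]."""
    tmp = (array1 - array2) % 360
    return np.where(tmp > 180, tmp - 360, tmp)


truth = np.array([350.0, 10.0, 180.0])
pred = np.array([10.0, 350.0, 170.0])
print(error_plot_in_degrees(truth, pred))  # [-20.  20.  10.]
```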
55 changes: 40 additions & 15 deletions src/anemoi/training/train/forecaster.py
@@ -84,7 +84,10 @@ def __init__(

self.logger_enabled = config.diagnostics.log.wandb.enabled or config.diagnostics.log.mlflow.enabled

self.metric_ranges, loss_scaling = self.metrics_loss_scaling(config, data_indices)
self.metric_ranges, self.metric_ranges_validation, loss_scaling = self.metrics_loss_scaling(
config,
data_indices,
)
self.loss = WeightedMSELoss(node_weights=self.loss_weights, data_variances=loss_scaling)
self.metrics = WeightedMSELoss(node_weights=self.loss_weights, ignore_nans=True)

@@ -127,8 +130,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
@staticmethod
def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, torch.Tensor]:
metric_ranges = defaultdict(list)
metric_ranges_validation = defaultdict(list)
loss_scaling = (
np.ones((len(data_indices.data.output.full),), dtype=np.float32) * config.training.loss_scaling.default
np.ones((len(data_indices.internal_data.output.full),), dtype=np.float32)
* config.training.loss_scaling.default
)

pressure_level = instantiate(config.training.pressure_level_scaler)
@@ -140,15 +145,17 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t
pressure_level.minimum,
)

for key, idx in data_indices.model.output.name_to_index.items():
for key, idx in data_indices.internal_model.output.name_to_index.items():
# Split pressure levels on "_" separator
split = key.split("_")
if len(split) > 1:
if len(split) > 1 and split[-1].isdigit():
# Create grouped metrics for pressure levels (e.g. Q, T, U, V, etc.) for logger
metric_ranges[f"pl_{split[0]}"].append(idx)
# Create pressure levels in loss scaling vector
if split[0] in config.training.loss_scaling.pl:
loss_scaling[idx] = config.training.loss_scaling.pl[split[0]] * pressure_level.scaler(int(split[1]))
loss_scaling[idx] = config.training.loss_scaling.pl[split[0]] * pressure_level.scaler(
int(split[-1]),
)
else:
LOGGER.debug("Parameter %s was not scaled.", key)
else:
@@ -162,7 +169,19 @@ def metrics_loss_scaling(config: DictConfig, data_indices: IndexCollection) -> t
if key in config.training.metrics:
metric_ranges[key] = [idx]
loss_scaling = torch.from_numpy(loss_scaling)
return metric_ranges, loss_scaling
# metric for validation, after postprocessing
for key, idx in data_indices.model.output.name_to_index.items():
# Split pressure levels on "_" separator
split = key.split("_")
if len(split) > 1 and split[1].isdigit():
# Create grouped metrics for pressure levels (e.g. Q, T, U, V, etc.) for logger
metric_ranges_validation[f"pl_{split[0]}"].append(idx)
else:
metric_ranges_validation[f"sfc_{key}"].append(idx)
# Create specific metrics from hydra to log in logger
if key in config.training.metrics:
metric_ranges_validation[key] = [idx]
return metric_ranges, metric_ranges_validation, loss_scaling

def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None:
LOGGER.debug("set_model_comm_group: %s", model_comm_group)
@@ -178,18 +197,18 @@ def advance_input(
x = x.roll(-1, dims=1)

# Get prognostic variables
x[:, -1, :, :, self.data_indices.model.input.prognostic] = y_pred[
x[:, -1, :, :, self.data_indices.internal_model.input.prognostic] = y_pred[
...,
self.data_indices.model.output.prognostic,
self.data_indices.internal_model.output.prognostic,
]

# get new "constants" needed for time-varying fields
x[:, -1, :, :, self.data_indices.model.input.forcing] = batch[
x[:, -1, :, :, self.data_indices.internal_model.input.forcing] = batch[
:,
self.multi_step + rollout_step,
:,
:,
self.data_indices.data.input.forcing,
self.data_indices.internal_data.input.forcing,
]
return x
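For readers unfamiliar with the rolling input window in `advance_input` above: `x.roll(-1, dims=1)` shifts every time step one slot earlier, so the last slot (which after the roll holds the wrapped-around oldest step) can be overwritten with the prognostic channels of the new prediction and the next forcing fields. A tiny sketch of just the roll semantics, with illustrative shapes:

```python
import torch

# (bs, multi_step, ensemble, latlon, nvar) - illustrative sizes only
x = torch.arange(2 * 2 * 1 * 3 * 4, dtype=torch.float32).reshape(2, 2, 1, 3, 4)
x_rolled = x.roll(-1, dims=1)

print(torch.equal(x_rolled[:, 0], x[:, 1]))   # True: the most recent step moved forward
print(torch.equal(x_rolled[:, -1], x[:, 0]))  # True: the oldest step wrapped to the last slot,
                                              # where it is then overwritten with y_pred and forcings
```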

@@ -201,18 +220,24 @@ def _step(
) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]:
del batch_idx
loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False)
batch = self.model.pre_processors(batch) # normalized in-place
# for validation not normalized in-place because remappers cannot be applied in-place
batch = self.model.pre_processors(batch, in_place=not validation_mode)
metrics = {}

# start rollout
x = batch[:, 0 : self.multi_step, ..., self.data_indices.data.input.full] # (bs, multi_step, latlon, nvar)
# start rollout of preprocessed batch
x = batch[
:,
0 : self.multi_step,
...,
self.data_indices.internal_data.input.full,
] # (bs, multi_step, latlon, nvar)

y_preds = []
for rollout_step in range(self.rollout):
# prediction at rollout step rollout_step, shape = (bs, latlon, nvar)
y_pred = self(x)

y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.data.output.full]
y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full]
# y includes the auxiliary variables, so we must leave those out when computing the loss
loss += checkpoint(self.loss, y_pred, y, use_reentrant=False)

Expand Down Expand Up @@ -243,7 +268,7 @@ def calculate_val_metrics(
y_preds = []
y_postprocessed = self.model.post_processors(y, in_place=False)
y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False)
for mkey, indices in self.metric_ranges.items():
for mkey, indices in self.metric_ranges_validation.items():
metrics[f"{mkey}_{rollout_step + 1}"] = self.metrics(
y_pred_postprocessed[..., indices],
y_postprocessed[..., indices],
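To make the name handling in `metrics_loss_scaling` concrete: variable names are split on `_`, and only names whose last token is numeric (e.g. `q_850`) are treated as pressure-level variables and receive pressure-level loss scaling, which is why the commit adds the `split[-1].isdigit()` guard (a name like `cos_latitude` splits on `_` but is not a pressure level). The sketch below reproduces that grouping logic in isolation; the scaling values and the linear `pressure_scaler` are illustrative assumptions, not the project's defaults.

```python
from collections import defaultdict

import numpy as np

names = ["q_850", "t_500", "u_850", "tp", "2t", "cos_latitude"]
pl_scaling = {"q": 0.6, "t": 6.0, "u": 0.8}  # illustrative per-variable weights


def pressure_scaler(level: int) -> float:  # illustrative stand-in for the configured scaler
    return level / 1000.0


metric_ranges = defaultdict(list)
loss_scaling = np.ones(len(names), dtype=np.float32)

for idx, key in enumerate(names):
    split = key.split("_")
    if len(split) > 1 and split[-1].isdigit():  # pressure-level variable, e.g. "q_850"
        metric_ranges[f"pl_{split[0]}"].append(idx)
        if split[0] in pl_scaling:
            loss_scaling[idx] = pl_scaling[split[0]] * pressure_scaler(int(split[-1]))
    else:  # surface variable, e.g. "tp" or "cos_latitude"
        metric_ranges[f"sfc_{key}"].append(idx)

print(dict(metric_ranges))
# {'pl_q': [0], 'pl_t': [1], 'pl_u': [2], 'sfc_tp': [3], 'sfc_2t': [4], 'sfc_cos_latitude': [5]}
print(loss_scaling)  # approximately [0.51, 3.0, 0.68, 1.0, 1.0, 1.0]
```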