From 92a4b0628429f15ee3511d1446ff42ed82e74d61 Mon Sep 17 00:00:00 2001 From: Ewan <131677160+da-ewanp@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:09:51 +0100 Subject: [PATCH] Feat: adding in ability to configure precip like plots (#49) * adding in ability to configure precip like plots * updating changelog * updating changelog * remove default arg for precip_like_vars * addressing comments in pr * using list type checking * using list type checking * updating config to list for precip_and_related_fields --- CHANGELOG.md | 1 + .../config/diagnostics/eval_rollout.yaml | 1 + .../diagnostics/callbacks/__init__.py | 6 +++ src/anemoi/training/diagnostics/plots.py | 44 +++++++++++++++++-- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48420a3b..3c7072e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,7 @@ Keep it human-readable, your future self will thank you! - Correct errors in callback plots - fix error in the default config - example slurm config +- ability to configure precip-type plots ### Changed diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index 4b746fbc..43fe06ce 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -24,6 +24,7 @@ plot: #Defining the accumulation levels for precipitation related fields and the colormap accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: [tp, cp] # Histogram and Spectrum plots parameters_histogram: - z_500 diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 5c0d8b0e..290bdfa6 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -510,6 +510,8 @@ def __init__(self, config: OmegaConf) -> None: """ super().__init__(config) self.sample_idx = self.config.diagnostics.plot.sample_idx + self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields + LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}") @rank_zero_only def _plot( @@ -563,6 +565,7 @@ def _plot( data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, ) self._output_figure( @@ -605,6 +608,8 @@ def __init__(self, config: OmegaConf) -> None: """ super().__init__(config) self.sample_idx = self.config.diagnostics.plot.sample_idx + self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields + LOGGER.info(f"Using precip histogram plotting method for fields: {self.precip_and_related_fields}") @rank_zero_only def _plot( @@ -658,6 +663,7 @@ def _plot( data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, ) self._output_figure( diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 8ca53159..8140da4b 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -248,6 +248,7 @@ def plot_histogram( x: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray, + precip_and_related_fields: list | None = None, ) -> Figure: """Plots histogram. @@ -264,6 +265,8 @@ def plot_histogram( Expected data of shape (lat*lon, nvar*level) y_pred : np.ndarray Predicted data of shape (lat*lon, nvar*level) + precip_and_related_fields : list, optional + List of precipitation-like variables, by default [] Returns ------- @@ -271,6 +274,8 @@ def plot_histogram( The figure object handle. """ + precip_and_related_fields = precip_and_related_fields or [] + n_plots_x, n_plots_y = len(parameters), 1 figsize = (n_plots_y * 4, n_plots_x * 3) @@ -300,7 +305,7 @@ def plot_histogram( hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, range=[bin_min, bin_max]) # Visualization trick for tp - if variable_name in {"tp", "cp"}: + if variable_name in precip_and_related_fields: # in-place multiplication does not work here because variables are different numpy types hist_yt = hist_yt * bins_yt[:-1] hist_yp = hist_yp * bins_yp[:-1] @@ -327,6 +332,7 @@ def plot_predicted_multilevel_flat_sample( x: np.ndarray, y_true: np.ndarray, y_pred: np.ndarray, + precip_and_related_fields: list | None = None, ) -> Figure: """Plots data for one multilevel latlon-"flat" sample. @@ -351,6 +357,8 @@ def plot_predicted_multilevel_flat_sample( Expected data of shape (lat*lon, nvar*level) y_pred : np.ndarray Predicted data of shape (lat*lon, nvar*level) + precip_and_related_fields : list, optional + List of precipitation-like variables, by default [] Returns ------- @@ -372,9 +380,33 @@ def plot_predicted_multilevel_flat_sample( yt = y_true[..., variable_idx].squeeze() yp = y_pred[..., variable_idx].squeeze() if n_plots_x > 1: - plot_flat_sample(fig, ax[plot_idx, :], pc_lon, pc_lat, xt, yt, yp, variable_name, clevels, cmap_precip) + plot_flat_sample( + fig, + ax[plot_idx, :], + pc_lon, + pc_lat, + xt, + yt, + yp, + variable_name, + clevels, + cmap_precip, + precip_and_related_fields, + ) else: - plot_flat_sample(fig, ax, pc_lon, pc_lat, xt, yt, yp, variable_name, clevels, cmap_precip) + plot_flat_sample( + fig, + ax, + pc_lon, + pc_lat, + xt, + yt, + yp, + variable_name, + clevels, + cmap_precip, + precip_and_related_fields, + ) return fig @@ -390,6 +422,7 @@ def plot_flat_sample( vname: str, clevels: float, cmap_precip: str, + precip_and_related_fields: list | None = None, ) -> None: """Plot a "flat" 1D sample. @@ -417,9 +450,12 @@ def plot_flat_sample( Accumulation levels used for precipitation related plots cmap_precip: str Colors used for each accumulation level + precip_and_related_fields : list, optional + List of precipitation-like variables, by default [] """ - if vname in {"tp", "cp"}: + precip_and_related_fields = precip_and_related_fields or [] + if vname in precip_and_related_fields: # Create a custom colormap for precipitation nws_precip_colors = cmap_precip precip_colormap = ListedColormap(nws_precip_colors)