diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 5ed3cf61..dcef7c22 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -344,7 +344,7 @@ def _plot( ) # prepare input tensor for plotting - input_tensor_0 = batch[ + input_tensor_0 = pl_module.model.pre_processors(batch, in_place=False)[ self.sample_idx, pl_module.multi_step - 1, ..., @@ -355,7 +355,12 @@ def _plot( # start rollout with torch.no_grad(): for rollout_step, (_, _, y_pred) in enumerate( - pl_module.rollout_step(batch, rollout=max(self.rollout), validation_mode=False, training_mode=False), + pl_module.rollout_step( + batch, + rollout=max(self.rollout), + validation_mode=False, + training_mode=False, + ), ): if (rollout_step + 1) in self.rollout: