Skip to content

Commit

Permalink
Apply post_processors before plotting in LongRolloutPlots
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Oct 28, 2024
1 parent 42b59e5 commit 30dfd45
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _plot(
)

# prepare input tensor for plotting
input_tensor_0 = batch[
input_tensor_0 = pl_module.model.pre_processors(batch, in_place=False)[
self.sample_idx,
pl_module.multi_step - 1,
...,
Expand All @@ -355,7 +355,12 @@ def _plot(
# start rollout
with torch.no_grad():
for rollout_step, (_, _, y_pred) in enumerate(
pl_module.rollout_step(batch, rollout=max(self.rollout), validation_mode=False, training_mode=False),
pl_module.rollout_step(
batch,
rollout=max(self.rollout),
validation_mode=False,
training_mode=False,
),
):

if (rollout_step + 1) in self.rollout:
Expand Down

0 comments on commit 30dfd45

Please sign in to comment.