Skip to content

Commit

Permalink
Fix support for timesteps list in case model has rain but radar does …
Browse files Browse the repository at this point in the history
…not (#411)

* Fix support for timesteps list in case model has rain but radar does not

* One more small fix and update tests

* black

* remove accidentally added duplicate test cases
  • Loading branch information
mats-knmi authored Jul 25, 2024
1 parent 07a5aa8 commit 8bec82e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 36 deletions.
48 changes: 19 additions & 29 deletions pysteps/blending/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,10 @@
from scipy.linalg import inv
from scipy.ndimage import binary_dilation, generate_binary_structure, iterate_structure

from pysteps import cascade
from pysteps import extrapolation
from pysteps import noise
from pysteps import utils
from pysteps import blending, cascade, extrapolation, noise, utils
from pysteps.nowcasts import utils as nowcast_utils
from pysteps.postprocessing import probmatching
from pysteps.timeseries import autoregression, correlation
from pysteps import blending

try:
import dask
Expand Down Expand Up @@ -578,6 +574,14 @@ def forecast(
precip_models_pm, precip_thr, norain_thr
)

if isinstance(timesteps, int):
timesteps = list(range(timesteps + 1))
timestep_type = "int"
else:
original_timesteps = [0] + list(timesteps)
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
timestep_type = "list"

# 2.3.1 If precip is below the norain threshold and precip_models_pm is zero,
# we consider it as no rain in the domain.
# The forecast will directly return an array filled with the minimum
Expand All @@ -591,14 +595,6 @@ def forecast(
# Create the output list
R_f = [[] for j in range(n_ens_members)]

if isinstance(timesteps, int):
timesteps = range(timesteps + 1)
timestep_type = "int"
else:
original_timesteps = [0] + list(timesteps)
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
timestep_type = "list"

# Save per time step to ensure the array does not become too large if
# no return_output is requested and callback is not None.
for t, subtimestep_idx in enumerate(timesteps):
Expand All @@ -610,12 +606,13 @@ def forecast(
R_f_ = np.full(
(n_ens_members, precip_shape[0], precip_shape[1]), np.nanmin(precip)
)
if callback is not None:
if R_f_.shape[1] > 0:
callback(R_f_.squeeze())
if return_output:
for j in range(n_ens_members):
R_f[j].append(R_f_[j])
if subtimestep_idx:
if callback is not None:
if R_f_.shape[1] > 0:
callback(R_f_.squeeze())
if return_output:
for j in range(n_ens_members):
R_f[j].append(R_f_[j])

R_f_ = None

Expand Down Expand Up @@ -680,7 +677,8 @@ def forecast(
precip_models_pm, precip_thr, precip_models_pm.shape[0], timesteps
)
# Make sure precip_noise_input is three dimensional
precip_noise_input = precip_noise_input[np.newaxis, :, :]
if len(precip_noise_input.shape) != 3:
precip_noise_input = precip_noise_input[np.newaxis, :, :]
else:
precip_noise_input = precip.copy()

Expand Down Expand Up @@ -782,14 +780,6 @@ def forecast(
if measure_time:
starttime_mainloop = time.time()

if isinstance(timesteps, int):
timesteps = range(timesteps + 1)
timestep_type = "int"
else:
original_timesteps = [0] + list(timesteps)
timesteps = nowcast_utils.binned_timesteps(original_timesteps)
timestep_type = "list"

extrap_kwargs["return_displacement"] = True
forecast_prev = precip_cascade
noise_prev = noise_cascade
Expand Down Expand Up @@ -2498,7 +2488,7 @@ def _determine_max_nr_rainy_cells_nwp(
max_rain_pixels_j = -1
max_rain_pixels_t = -1
for j in range(n_models):
for t in range(timesteps):
for t in timesteps:
rain_pixels = precip_models_pm[j][t][
precip_models_pm[j][t] > precip_thr
].size
Expand Down
25 changes: 18 additions & 7 deletions pysteps/tests/test_blending_steps.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-

import numpy as np
import datetime

import numpy as np
import pytest
import pysteps
from pysteps import cascade, blending

import pysteps
from pysteps import blending, cascade

steps_arg_values = [
(1, 3, 4, 8, None, None, False, "spn", True, 4, False, False, 0, False),
Expand All @@ -14,6 +15,7 @@
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, None, "mean", False, "spn", True, 4, False, False, 0, True),
(1, 3, 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
(1, [1, 2, 3], 4, 8, None, "cdf", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 8, "incremental", "cdf", False, "spn", True, 4, False, False, 0, False),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", True, 4, False, False, 0, False),
(1, 3, 4, 6, "incremental", "cdf", False, "bps", False, 4, False, False, 0, False),
Expand Down Expand Up @@ -42,11 +44,13 @@
(1, 3, 6, 8, None, None, False, "spn", True, 6, False, True, 80, False),
(5, 3, 5, 6, "incremental", "cdf", False, "spn", False, 5, True, False, 80, True),
(5, 3, 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
(5, [1, 2, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
(5, [1, 3], 5, 6, "obs", "mean", False, "spn", False, 5, True, True, 80, False),
]

steps_arg_names = (
"n_models",
"n_timesteps",
"timesteps",
"n_ens_members",
"n_cascade_levels",
"mask_method",
Expand All @@ -65,7 +69,7 @@
@pytest.mark.parametrize(steps_arg_names, steps_arg_values)
def test_steps_blending(
n_models,
n_timesteps,
timesteps,
n_ens_members,
n_cascade_levels,
mask_method,
Expand All @@ -85,7 +89,14 @@ def test_steps_blending(
# The input data
###
# Initialise dummy NWP data
nwp_precip = np.zeros((n_models, n_timesteps + 1, 200, 200))
if not isinstance(timesteps, int):
n_timesteps = len(timesteps)
last_timestep = timesteps[-1]
else:
n_timesteps = timesteps
last_timestep = timesteps

nwp_precip = np.zeros((n_models, last_timestep + 1, 200, 200))

if not zero_nwp:
for n_model in range(n_models):
Expand Down Expand Up @@ -250,7 +261,7 @@ def test_steps_blending(
precip_models=nwp_precip_decomp,
velocity=radar_velocity,
velocity_models=nwp_velocity,
timesteps=n_timesteps,
timesteps=timesteps,
timestep=5.0,
issuetime=datetime.datetime.strptime("202112012355", "%Y%m%d%H%M"),
n_ens_members=n_ens_members,
Expand Down

0 comments on commit 8bec82e

Please sign in to comment.