This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Edits to spatial train/val/test, additional performance metrics #211

Merged: 24 commits, Apr 26, 2023
The diff below shows changes from 7 of the commits.
Commits
de51eea
add kge for log values, and kge for upper and lower 10th of dataset
jds485 Mar 7, 2023
e30e3cc
add args to filter the data by an earliest and latest time index
jds485 Mar 8, 2023
13ee6cb
add arg to make a strict spatial partition based on the provided val …
jds485 Mar 8, 2023
e014068
removing training sites from validation data
jds485 Mar 8, 2023
f5d132f
ensuring that this spatial data filter will work if training and vali…
jds485 Mar 8, 2023
85c6917
ensure that validation sites are not in the test set
jds485 Mar 8, 2023
37fc891
add options for biweekly, monthly, and yearly timeseries summaries, a…
jds485 Mar 10, 2023
fe69f86
add metrics, edit comments
jds485 Mar 14, 2023
81dca56
add count of observations to the metrics files
jds485 Mar 14, 2023
89ec644
change groups arg to several args that describe how to group_temporal…
jds485 Mar 21, 2023
a4930cd
revert back to 10 as minimum number of observations
jds485 Mar 29, 2023
a7e35b5
change sum to mean, and change name of function arg to time_aggregati…
jds485 Mar 29, 2023
c60adc5
update comment
jds485 Mar 29, 2023
5535ba4
change variable name from sum to mean
jds485 Apr 18, 2023
4ec8543
update parameter description
jds485 Apr 18, 2023
03a8ae0
check that trn, val, and tst partitions have unique time-site ids
jds485 Apr 18, 2023
1f1d6f7
check partitions with a nan filtering
jds485 Apr 18, 2023
029a933
make the handling of site selection consistent in evaluate and prepro…
jds485 Apr 18, 2023
71d4bfd
remove sites from partitions only if sites are not provided for those…
jds485 Apr 18, 2023
4875ffc
fix line indentation
jds485 Apr 18, 2023
2b59d01
allow overriding pretraining partition check
jds485 Apr 18, 2023
7d630f8
Merge branch 'main' into jds-edits
jds485 Apr 20, 2023
6c2f18d
add small offset for log metrics
jds485 Apr 26, 2023
c2e08ad
handle case when any of trn, val, tst are not used
jds485 Apr 26, 2023
179 changes: 146 additions & 33 deletions river_dl/evaluate.py
@@ -65,7 +65,7 @@ def rmse_logged(y_true, y_pred):

def nse_logged(y_true, y_pred):
"""
compute the rmse of the logged data
compute the nse of the logged data
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:return: [float] the nse of the logged data
@@ -77,16 +77,32 @@ def nse_logged(y_true, y_pred):

def kge_eval(y_true, y_pred):
y_true, y_pred = filter_nan_preds(y_true, y_pred)
r, _ = pearsonr(y_pred, y_true)
mean_true = np.mean(y_true)
mean_pred = np.mean(y_pred)
std_true = np.std(y_true)
std_pred = np.std(y_pred)
r_component = np.square(r - 1)
std_component = np.square((std_pred / std_true) - 1)
bias_component = np.square((mean_pred / mean_true) - 1)
return 1 - np.sqrt(r_component + std_component + bias_component)
#Need to have > 1 observation to compute correlation.
#This could be < 2 due to percentile filtering
if len(y_true) > 1:
r, _ = pearsonr(y_pred, y_true)
mean_true = np.mean(y_true)
mean_pred = np.mean(y_pred)
std_true = np.std(y_true)
std_pred = np.std(y_pred)
r_component = np.square(r - 1)
std_component = np.square((std_pred / std_true) - 1)
bias_component = np.square((mean_pred / mean_true) - 1)
result = 1 - np.sqrt(r_component + std_component + bias_component)
else:
result = np.nan
return result

def kge_logged(y_true, y_pred):
"""
compute the kge of the logged data
:param y_true: [array-like] observed y_dataset values
:param y_pred: [array-like] predicted y_dataset values
:return: [float] the kge of the logged data
"""
y_true, y_pred = filter_nan_preds(y_true, y_pred)
y_true, y_pred = filter_negative_preds(y_true, y_pred)
return kge_eval(np.log(y_true), np.log(y_pred))
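
For reference, kge_eval above computes the 2009 Kling-Gupta efficiency, and kge_logged applies the same statistic to log-transformed values:

$$\mathrm{KGE} = 1 - \sqrt{(r - 1)^2 + \left(\frac{\sigma_{\mathrm{pred}}}{\sigma_{\mathrm{obs}}} - 1\right)^2 + \left(\frac{\mu_{\mathrm{pred}}}{\mu_{\mathrm{obs}}} - 1\right)^2}$$

where r is the Pearson correlation between predictions and observations, σ and μ are the standard deviations and means, and a value of 1 indicates a perfect fit.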

def filter_by_percentile(y_true, y_pred, percentile, less_than=True):
"""
@@ -136,7 +152,7 @@ def calc_metrics(df):
pred = df["pred"].values
obs, pred = filter_nan_preds(obs, pred)

if len(obs) > 10:
if len(obs) > 20:
metrics = {
"rmse": rmse_eval(obs, pred),
"nse": nse_eval(obs, pred),
@@ -162,12 +178,10 @@
),
"nse_logged": nse_logged(obs, pred),
"kge": kge_eval(obs, pred),
"rmse_logged": rmse_logged(obs, pred),
"nse_top10": percentile_metric(obs, pred, nse_eval, 90, less_than=False),
"nse_bot10": percentile_metric(obs, pred, nse_eval, 10, less_than=True),
"nse_logged": nse_logged(obs, pred),
"kge_logged": kge_logged(obs, pred),
"kge_top10": percentile_metric(obs, pred, kge_eval, 90, less_than=False),
"kge_bot10": percentile_metric(obs, pred, kge_eval, 10, less_than=True)
}

else:
metrics = {
"rmse": np.nan,
Expand All @@ -182,10 +196,9 @@ def calc_metrics(df):
"nse_bot10": np.nan,
"nse_logged": np.nan,
"kge": np.nan,
"rmse_logged": np.nan,
"nse_top10": np.nan,
"nse_bot10": np.nan,
"nse_logged": np.nan,
"kge_logged": np.nan,
"kge_top10": np.nan,
"kge_bot10": np.nan
}
return pd.Series(metrics)

@@ -224,27 +237,79 @@ def partition_metrics(
:param outfile: [str] file where the metrics should be written
:param val_sites: [list] sites to exclude from training and test metrics
:param test_sites: [list] sites to exclude from validation and training metrics
:param train_sites: [list] sites to exclude from test metrics
:param train_sites: [list] sites to exclude from validation and test metrics
:return: [pd dataframe] the condensed metrics
"""
var_data = fmt_preds_obs(preds, obs_file, spatial_idx_name,
time_idx_name)
var_metrics_list = []

for data_var, data in var_data.items():
#multiindex df
data_multiind = data.copy(deep=True)
data.reset_index(inplace=True)
# mask out validation and test sites from trn partition
if val_sites and partition == 'trn':
data = data[~data[spatial_idx_name].isin(val_sites)]
if test_sites and partition == 'trn':
data = data[~data[spatial_idx_name].isin(test_sites)]
# mask out test sites from val partition
if test_sites and partition=='val':
data = data[~data[spatial_idx_name].isin(test_sites)]
if train_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(train_sites)]
if val_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(val_sites)]
if train_sites and partition == 'trn':
Collaborator:
I agree that there is an omission in the original code and that the following code should be added to correct that:

if train_sites and partition=='val':
    data = data[~data[spatial_idx_name].isin(train_sites)]

If I'm understanding them correctly, I think the suggested edits change the functionality of this section. As I read the original code, it appears that train_sites (as well as test_sites and val_sites) were sites that only appeared in that partition, but they weren't necessarily the only sites in that partition. In the revised code, it appears that if train_sites is specified, it will use only those sites in the training evaluations (and remove those sites from the test and validation partition).

If the intention is to change the functionality of train_sites, etc, then it's probably good to have a broader conversation. I am not using that option in my projects currently, but I don't know if others are.

Member Author (jds485):
> As I read the original code, it appears that train_sites (as well as test_sites and val_sites) were sites that only appeared in that partition, but they weren't necessarily the only sites in that partition.

Yes, I agree. I wasn't sure if this was intentional. I could add another explicit_spatial_split parameter here to allow for the previous functionality when False and this new functionality when True. I'll hold off on that before receiving feedback from others
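
A rough sketch of what that toggle could look like (the helper and its arguments are hypothetical, not code from this PR; the preprocessing changes later in this PR expose a similar strict_spatial_partition flag):

def filter_partition_sites(data, spatial_idx_name, partition_sites,
                           other_partition_sites, explicit_spatial_split=False):
    # Hypothetical helper: partition_sites are the sites assigned to the
    # partition being evaluated; other_partition_sites is a list of the site
    # lists reserved for the remaining partitions.
    if explicit_spatial_split and partition_sites:
        # strict split: evaluate this partition on its own sites only
        return data[data[spatial_idx_name].isin(partition_sites)]
    # previous behavior: keep all sites except those reserved elsewhere
    for sites in other_partition_sites:
        if sites:
            data = data[~data[spatial_idx_name].isin(sites)]
    return data

For the 'trn' partition this would be called as filter_partition_sites(data, spatial_idx_name, train_sites, [val_sites, test_sites], explicit_spatial_split).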

Collaborator:
Sounds good.

Collaborator:
Oh wow. Yeah. This was definitely a bug! 😮 Thanks for catching this.

Member Author (jds485):
@janetrbarclay: in a comment below, Jeff suggests we assume that sites within the train/val/test are the only sites in those partitions. That's also what I would expect. Do you know of anyone who is relying on the previous method wherein sites that are not within train_sites/val_sites/test_sites could be in the train/val/test partitions?

# simply use the train sites when specified.
data = data[data[spatial_idx_name].isin(train_sites)]
data_multiind = data_multiind.loc[data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(train_sites)]
else:
#check if validation or testing sites are specified
if val_sites and partition == 'trn':
data = data[~data[spatial_idx_name].isin(val_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(val_sites)]
if test_sites and partition == 'trn':
data = data[~data[spatial_idx_name].isin(test_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(test_sites)]
jds485 marked this conversation as resolved.
Show resolved Hide resolved
# mask out training and test sites from val partition
if val_sites and partition == 'val':
data = data[data[spatial_idx_name].isin(val_sites)]
data_multiind = data_multiind.loc[data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(val_sites)]
else:
if test_sites and partition=='val':
data = data[~data[spatial_idx_name].isin(test_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(test_sites)]
if train_sites and partition=='val':
data = data[~data[spatial_idx_name].isin(train_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(train_sites)]
# mask out training and validation sites from tst partition
if test_sites and partition == 'tst':
data = data[data[spatial_idx_name].isin(test_sites)]
data_multiind = data_multiind.loc[data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(test_sites)]
else:
if train_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(train_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(train_sites)]
if val_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(val_sites)]
data_multiind = data_multiind.loc[~data_multiind
.index
.get_level_values(level=spatial_idx_name)
.isin(val_sites)]

if not group:
metrics = calc_metrics(data)
@@ -268,6 +333,54 @@
.apply(calc_metrics)
.reset_index()
)
elif group == "year":
Collaborator:
This is getting to be quite verbose. I suggest we split the group argument into two, maybe (group_spatially, and group_temporally). The group_spatially would be just a boolean. The group_temporally would be a str.

then the function could be something like:

if not group_spatially and not group_temporally:
    metrics = calc_metrics(data)
    # need to convert to dataframe and transpose so it looks like the
    # others
    metrics = pd.DataFrame(metrics).T
elif group_spatially and not group_temporally:
    metrics = data.groupby(spatial_idx_name).apply(calc_metrics).reset_index()
elif not group_spatially and group_temporally:
    metrics = data.groupby(pd.Grouper(key=time_idx_name, freq=group_temporally)).apply(calc_metrics).reset_index()
elif group_spatially and group_temporally:
    metrics = data.groupby([pd.Grouper(key=time_idx_name, freq=group_temporally),
                            pd.Grouper(key=spatial_idx_name)]
                            ).apply(calc_metrics).reset_index()

I think that should work. We'd just have to document how the group_temporally argument needs to work.

Collaborator:
Definitely a bigger change, but I think it would be worth trying. It would also require propagating the change all the way up including any Snakefiles that are using this function.

Collaborator:
Would also resolve #212

Member Author (jds485):
Thanks, that is much cleaner. I think one more argument would be needed for how to do temporal aggregation. I think what you've programmed would compute metrics for native timestep only (daily). I used a sum of the daily data to get biweekly, monthly, and yearly timesteps and compute metrics for those. Let me try out this edit because it will make the code much cleaner
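
To illustrate the distinction with a minimal sketch (pandas usage and column names are illustrative only, assuming data has date, seg_id_nat, obs, and pred columns as expected by calc_metrics):

import pandas as pd

# metrics per calendar month, computed on the daily values
per_month = data.groupby(pd.Grouper(key="date", freq="M")).apply(calc_metrics)

# aggregate to a monthly series first (sum of daily values per site and month),
# then compute one set of metrics on the aggregated values
monthly = (data.dropna()
               .groupby([pd.Grouper(key="date", freq="M"), "seg_id_nat"])
               .sum())
monthly_metrics = calc_metrics(monthly)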

Member Author (jds485):
I addressed this change in the most recent commit. I needed 4 group args to replicate the previous functionality. I think it's more generic now for using different timesteps of aggregation, but it's more difficult to understand how to apply the 4 args to compute the desired metrics. For reference, here is the function that I used in the snakemake file to assign the 4 new args to this function to compute different metrics for different groups. See the description of the 4 args in the function docstring.

#Order in the list is: 
#group_spatially (bool), group_temporally (False or timestep to use), sum_aggregation (bool), site_based (bool)
def get_grp_arg(wildcards):
	if wildcards.metric_type == 'overall':
		return [False, False, False, False]
	elif wildcards.metric_type == 'month':
		return [False, 'M', False, False]
	elif wildcards.metric_type == 'reach':
		return [True, False, False, False]
	elif wildcards.metric_type == 'month_reach':
		return [True, 'M', False, False]
	elif wildcards.metric_type == 'monthly_site_based':
		return [False, 'M', True, True]
	elif wildcards.metric_type == 'monthly_all_sites':
		return [False, 'M', True, False]
	elif wildcards.metric_type == 'monthly_reach':
		return [True, 'M', True, False]
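
For context, a hypothetical Snakemake rule sketching how a wildcard-based params function like this could pass the four values through to the metrics step (the rule, file paths, and the argument names passed to combined_metrics are placeholders, not the actual workflow):

rule calc_metrics:
    input:
        preds = "results/{partition}_preds.feather",
        obs = "data/obs.zarr",
    output:
        "results/{metric_type}_{partition}_metrics.csv",
    params:
        grp = get_grp_arg,
    run:
        # combined_metrics would be imported from river_dl.evaluate in the Snakefile;
        # argument names below are placeholders for the refactored signature
        combined_metrics(
            obs_file=input.obs,
            pred_file=input.preds,
            group_spatially=params.grp[0],
            group_temporally=params.grp[1],
            sum_aggregation=params.grp[2],
            site_based=params.grp[3],
            outfile=output[0],
        )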

metrics = (
data.groupby(
data[time_idx_name].dt.year)
.apply(calc_metrics)
.reset_index()
)
elif group == ["seg_id_nat", "year"]:
metrics = (
data.groupby(
[data[time_idx_name].dt.year,
spatial_idx_name])
.apply(calc_metrics)
.reset_index()
)
elif group == "biweekly":
Collaborator:
I think the addition of biweekly and yearly options for the metrics is great.

If I'm reading this correctly, "biweekly", "monthly", and "yearly" all also use "seg_id_nat". For consistency with the other grouped metrics, it seems good to have that included in the group list. (so group = ['seg_id_nat','biweekly'])

(and as an aside, I'm noticing we should remove the hardcoded reference to seg_id_nat and replace it with spatial_idx_name. I think it's just the 3 references in this section. Would you want to fix that in this PR since you're already editing this section?)

Also, without running (which I haven't done) I'm not sure how monthly and yearly are different from ['seg_id_nat','month'] and ['seg_id_nat','year'] since they are both grouping on the same things.

Member Author (jds485), Mar 14, 2023:
> I'm not sure how monthly and yearly are different from ['seg_id_nat','month'] and ['seg_id_nat','year'] since they are both grouping on the same things

The biweekly, monthly and yearly options are resampling the daily timeseries to those time steps by taking the sum of the data within those time periods (only for the days with observations). I'm not sure that sum is the best option and am open to other suggestions.

The resulting performance metrics are computed over all reaches, not by reach as with the ['seg_id_nat','time'] options, so I can add the group option that reports these metrics by reach

Member Author (jds485):
I agree with removing the "seg_id_nat" comparison to be more generic, but that will affect snakemake workflows. For example, the workflow examples all have a function that defines group using seg_id_nat. Might be better to address this problem in a separate issue

Collaborator:
> The resulting performance metrics are computed over all reaches, not by reach as with the ['seg_id_nat','time'] options, so I can add the group option that reports these metrics by reach

I'm not super familiar with the pandas grouper (so maybe that's the source of my confusion), but both monthly and yearly use 2 pandas groupers, 1 on time_idx_name and one on spatial_idx_name, right? So are you summing by reach and then calculating metrics across all the reaches?

Member Author (jds485):
> So are you summing by reach and then calculating metrics across all the reaches?

yes, that's right

Member Author (jds485), Mar 14, 2023:

> I can add the group option that reports these metrics by reach

Edit: I added metrics for reach-biweekly, reach-monthly and reach-yearly timeseries.

We could also have reach-biweekly-month (summarize the biweekly timeseries by month), reach-biweekly-year, and reach-monthly-year. reach-biweekly-time would require an additional function to define a biweekly index for Python datetime objects.
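
One possible shape for such a helper (hypothetical name and binning convention, assuming ISO weeks paired 1-2, 3-4, ... within each ISO year):

import pandas as pd

def biweekly_index(dates):
    # label each datetime with a (year, biweekly period) tuple so it can be
    # used as a grouping key, analogous to .dt.year or .dt.month
    dates = pd.DatetimeIndex(dates)
    iso = dates.isocalendar()
    return list(zip(iso.year, (iso.week + 1) // 2))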

#filter the data to remove nans before computing the sum
#so that the same days are being summed in each biweekly period.
data_calc = (data_multiind.dropna()
.groupby(
[pd.Grouper(level=time_idx_name, freq='2W'),
pd.Grouper(level=spatial_idx_name)])
.sum()
)
metrics = calc_metrics(data_calc)
metrics = pd.DataFrame(metrics).T
elif group == "monthly":
#filter the data to remove nans before computing the sum
#so that the same days are being summed in the month.
data_calc = (data_multiind.dropna()
.groupby(
[pd.Grouper(level=time_idx_name, freq='M'),
pd.Grouper(level=spatial_idx_name)])
.sum()
)
metrics = calc_metrics(data_calc)
metrics = pd.DataFrame(metrics).T
elif group == "yearly":
#filter the data to remove nans before computing the sum
#so that the same days are being summed in the year.
data_calc = (data_multiind.dropna()
.groupby(
[pd.Grouper(level=time_idx_name, freq='Y'),
pd.Grouper(level=spatial_idx_name)])
.sum()
)
metrics = calc_metrics(data_calc)
metrics = pd.DataFrame(metrics).T
else:
raise ValueError("group value not valid")

@@ -356,7 +469,7 @@ def combined_metrics(
group=group,
val_sites = val_sites,
test_sites = test_sites,
train_sites=train_sites)
train_sites = train_sites)
df_all.extend([metrics])

df_all = pd.concat(df_all, axis=0)
36 changes: 33 additions & 3 deletions river_dl/preproc_utils.py
@@ -596,6 +596,7 @@ def prep_y_data(
test_end_date=None,
val_sites=None,
test_sites=None,
strict_spatial_partition=False,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
seq_len=365,
@@ -632,9 +633,12 @@
:param test_end_date: [str or list] fmt: "YYYY-MM-DD"; date(s) to end test
period (can have multiple discontinuous periods)
:param val_sites: [list of site_ids] sites to retain for validation. These
sites will be withheld from training
sites will be withheld from training and testing
:param test_sites: [list of site_ids] sites to retain for testing. These
sites will be withheld from training and validation
:param strict_spatial_partition: [bool] when True, the test set does not
contain any reaches that are used in training or validation, and the
validation set does not contain any reaches that are used in training or testing.
:param seq_len: [int] length of sequences (e.g., 365)
:param log_vars: [list-like] which variables_to_log (if any) to take log of
:param exclude_file: [str] path to exclude file
@@ -674,10 +678,15 @@
# replace validation sites' (and test sites') data with np.nan
if val_sites:
y_trn = y_trn.where(~y_trn[spatial_idx_name].isin(val_sites))
y_tst = y_tst.where(~y_tst[spatial_idx_name].isin(val_sites))
if strict_spatial_partition:
y_val = y_val.where(y_val[spatial_idx_name].isin(val_sites))

if test_sites:
y_trn = y_trn.where(~y_trn[spatial_idx_name].isin(test_sites))
y_val = y_val.where(~y_val[spatial_idx_name].isin(test_sites))
if strict_spatial_partition:
y_tst = y_tst.where(y_tst[spatial_idx_name].isin(test_sites))


if log_vars:
@@ -757,6 +766,7 @@ def prep_all_data(
test_end_date=None,
val_sites=None,
test_sites=None,
strict_spatial_partition=False,
y_vars_finetune=None,
y_vars_pretrain=None,
spatial_idx_name="seg_id_nat",
@@ -772,6 +782,8 @@
log_y_vars=False,
out_file=None,
segs=None,
earliest_time=None,
latest_time=None,
normalize_y=True,
trn_offset = 1.0,
tst_val_offset = 1.0
@@ -782,7 +794,9 @@
scaled to have a std of 1 and a mean of zero
:param x_data_file: [str] path to Zarr file with x data. Data should have
a spatial coordinate and a time coordinate that are specified in the
`spatial_idx_name` and `time_idx_name` arguments
`spatial_idx_name` and `time_idx_name` arguments. Assumes that all spaces will be used,
unless segs is specified. Assumes all times will be used,
unless an earliest_time or latest_time is specified.
:param y_data_file: [str] observations Zarr file. Data should have a spatial
coordinate and a time coordinate that are specified in the
spatial_idx_name` and `time_idx_name` arguments
Expand All @@ -802,6 +816,9 @@ def prep_all_data(
sites will be withheld from training
:param test_sites: [list of site_ids] sites to retain for testing. These
sites will be withheld from training and validation
:param strict_spatial_partition: [bool] when True, the test set does not
contain any reaches that are used in training or validation, and the
validation set does not contain any reaches that are used in training or testing.
:param spatial_idx_name: [str] name of column that is used for spatial
index (e.g., 'seg_id_nat')
:param time_idx_name: [str] name of column that is used for temporal index
@@ -827,6 +844,8 @@
:param log_y_vars: [bool] whether or not to take the log of discharge in
training
:param segs: [list-like] which segments to prepare the data for
:param earliest_time: [str] when specified, filters the x_data to remove earlier times
:param latest_time: [str] when specified, filters the x_data to remove later times
:param normalize_y: [bool] whether or not to normalize the y_dataset values
:param out_file: [str] file to where the values will be written
:param trn_offset: [str] value for the training offset
@@ -867,15 +886,25 @@

if segs:
x_data = x_data.sel({spatial_idx_name: segs})

if earliest_time:
mask_etime = (x_data[time_idx_name] >= np.datetime64(earliest_time))
x_data = x_data.where(mask_etime, drop=True)

if latest_time:
mask_ltime = (x_data[time_idx_name] <= np.datetime64(latest_time))
x_data = x_data.where(mask_ltime, drop=True)

x_data = x_data[x_vars]

if catch_prop_file:
x_data = prep_catch_props(x_data, catch_prop_file, catch_prop_vars, spatial_idx_name)
#update the list of x_vars
x_vars = list(x_data.data_vars)

# make sure we don't have any weird or missing input values
check_if_finite(x_data)

x_trn, x_val, x_tst = separate_trn_tst(
x_data,
time_idx_name,
Expand All @@ -886,7 +915,7 @@ def prep_all_data(
test_start_date,
test_end_date,
)

x_trn_scl, x_std, x_mean = scale(x_trn)

x_scl, _, _ = scale(x_data,std=x_std,mean=x_mean)
@@ -1001,6 +1030,7 @@ def prep_all_data(
test_end_date=test_end_date,
val_sites=val_sites,
test_sites=test_sites,
strict_spatial_partition=strict_spatial_partition,
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name,
seq_len=seq_len,