This repository has been archived by the owner on Jun 2, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathSnakefile_b_ml_model_baseline
146 lines (135 loc) · 8.35 KB
/
Snakefile_b_ml_model_baseline
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import importlib
import os
import sys
from river_dl.preproc_utils import prep_all_data
sys.path.append('03b_model/src')
import run_model
# Load the fetch-step and model-step settings into the global `config` dict.
configfile: "01_fetch/wildcards_fetch_config.yaml"
configfile: "03b_model/model_config.yaml"

# Site / product lists read from the per-script sections of the config.
usgs_nwis_sites = config["fetch_usgs_nwis.py"]["sites"]
noaa_nos_sites = config["fetch_noaa_nos.py"]["sites"]
noaa_nos_products = config["fetch_noaa_nos.py"]["products"]
noaa_nerr_site = config["fetch_noaa_nerrs.py"]["sites"]

# Two-character, zero-padded replicate labels ("00", "01", ...), one per
# configured replicate.
replicates = [str(x).zfill(2) for x in range(config['replicates'])]

# Raw string so that "\d" reaches the regex engine as written instead of
# being parsed as an (invalid) Python string escape.
# NOTE(review): this constrains a wildcard named `noaa_nos_sites` (plural),
# while the expand() calls in this file use the keyword `noaa_nos_site`
# (singular) - confirm which name the included rules actually use.
wildcard_constraints:
    noaa_nos_sites=r"\d+"

include: "Snakefile_fetch_munge"
###########################################################################################
#RULE ALL
###########################################################################################
# Aggregate target for the baseline ML model workflow: requests the munged
# site files, the prepared model inputs, and the per-run training and
# prediction artifacts.
# NOTE(review): the three "03b_model/out/..." targets are hard-coded, whereas
# the per-run targets are built from config['out_dir'] - confirm both point at
# the same directory.
rule all_b_ml_model_baseline:
    input:
        # munged observation files, one per configured site
        expand("02_munge/out/D/usgs_nwis_{usgs_nwis_site}.csv",
               usgs_nwis_site=usgs_nwis_sites),
        expand("02_munge/out/daily_summaries/noaa_nos_{noaa_nos_site}.csv",
               noaa_nos_site=noaa_nos_sites),
        expand("02_munge/out/D/noaa_nerrs_{noaa_nerr_site}.csv",
               noaa_nerr_site=noaa_nerr_site),
        # prepared model inputs / targets
        "03b_model/out/inputs.zarr",
        "03b_model/out/target.zarr",
        "03b_model/out/prepped_model_io_data",
        # per-run training and prediction artifacts
        [
            os.path.join(config['out_dir'], config['run_id'], artifact)
            for artifact in (
                'model_param_output.txt',
                'losses.png',
                'weights.pt',
                "ModelResults.csv",
                "ModelResultsTimeSeries.png",
            )
        ]
##---------------------------------------------------------------------------------------##
###########################################################################################
#PREPARE THE INPUT DATA
###########################################################################################
# Select the model inputs and target, then prepare the train / validation /
# test data consumed by the modeling rules.
#
# The declared outputs are derived from config['out_dir'], the same directory
# that is handed to run_model below and that train_model reads
# 'prepped_model_io_data' from, so the rule's outputs cannot drift away from
# where downstream rules look for them.
# NOTE(review): assumes run_model writes inputs.zarr, target.zarr and
# prepped_model_io_data directly under `out_dir` - confirm in 03b_model/src.
rule prepare_inputs_targets:
    input:
        "03b_model/model_config.yaml",
        expand("02_munge/out/D/usgs_nwis_{usgs_nwis_site}.csv",
               usgs_nwis_site=usgs_nwis_sites),
        expand("02_munge/out/daily_summaries/noaa_nos_{noaa_nos_site}.csv",
               noaa_nos_site=noaa_nos_sites),
        expand("02_munge/out/D/noaa_nerrs_{noaa_nerr_site}.csv",
               noaa_nerr_site=noaa_nerr_site)
    output:
        directory(os.path.join(config['out_dir'], "inputs.zarr")),
        directory(os.path.join(config['out_dir'], "target.zarr")),
        os.path.join(config['out_dir'], "prepped_model_io_data")
    run:
        # Step 1: pick the configured input and target variables over the
        # full train-start .. test-end window.
        inputs_xarray, target_xarray = run_model.select_inputs_targets(
            inputs=config['inputs'],
            target=config['target'],
            train_start_date=config['train_start_date'],
            test_end_date=config['test_end_date'],
            out_dir=config['out_dir'],
            inc_ante=config['include_antecedant_data'])

        # Step 2: prepare the selected data for the train / val / test
        # periods using the configured sequence length and offset.
        run_model.prep_input_target_data(
            inputs_xarray=inputs_xarray,
            target_xarray=target_xarray,
            train_start_date=config['train_start_date'],
            train_end_date=config['train_end_date'],
            val_start_date=config['val_start_date'],
            val_end_date=config['val_end_date'],
            test_start_date=config['test_start_date'],
            test_end_date=config['test_end_date'],
            seq_len=config['seq_len'],
            offset=config['offset'],
            out_dir=config['out_dir'])
##---------------------------------------------------------------------------------------##
###########################################################################################
#TRAIN THE MODEL
###########################################################################################
# Train a single model from the prepped model I/O data and write the
# parameter summary, loss plot and weights under <out_dir>/<run_id>/.
rule train_model:
    input:
        "03b_model/model_config.yaml",
        prepped=os.path.join(config['out_dir'], 'prepped_model_io_data')
    output:
        os.path.join(config['out_dir'], config['run_id'], 'model_param_output.txt'),
        os.path.join(config['out_dir'], config['run_id'], 'losses.png'),
        os.path.join(config['out_dir'], config['run_id'], 'weights.pt')
    run:
        # Settings forwarded to run_model.train_model under the same name
        # they have in the config file.
        passthrough_keys = (
            'inputs', 'seq_len', 'hidden_units', 'recur_dropout', 'dropout',
            'n_epochs', 'learn_rate', 'out_dir', 'run_id',
            'train_start_date', 'train_end_date',
            'val_start_date', 'val_end_date',
            'test_start_date', 'test_end_date',
        )
        train_kwargs = {key: config[key] for key in passthrough_keys}

        run_model.train_model(
            prepped_model_io_data_file=input.prepped,
            # config key and keyword name differ for this one
            inc_ante=config['include_antecedant_data'],
            **train_kwargs)
##---------------------------------------------------------------------------------------##
###########################################################################################
#MAKE PLOT AND SAVE MODEL PREDICTIONS
###########################################################################################
# Generate predictions for the configured run, then save them as a CSV and a
# time-series plot under <out_dir>/<run_id>/.
#
# The trained weights are declared as an input so that Snakemake always
# schedules this rule after train_model; without it the two rules have no
# dependency on each other and may run in either order or in parallel.
# NOTE(review): assumes run_model.make_predictions loads
# <out_dir>/<run_id>/weights.pt - confirm in 03b_model/src/run_model.py.
rule make_plot_save_predictions:
    input:
        "03b_model/model_config.yaml",
        os.path.join(config['out_dir'], 'prepped_model_io_data'),
        # appended last so the positional indices of the inputs above are unchanged
        os.path.join(config['out_dir'], config['run_id'], 'weights.pt')
    output:
        os.path.join(config['out_dir'], config['run_id'], "ModelResults.csv"),
        os.path.join(config['out_dir'], config['run_id'], "ModelResultsTimeSeries.png")
    run:
        predictions = run_model.make_predictions(
            prepped_model_io_data_file=input[1],
            hidden_units=config['hidden_units'],
            recur_dropout=config['recur_dropout'],
            dropout=config['dropout'],
            n_epochs=config['n_epochs'],
            learn_rate=config['learn_rate'],
            out_dir=config['out_dir'],
            run_id=config['run_id'],
            train_start_date=config['train_start_date'],
            train_end_date=config['train_end_date'],
            val_start_date=config['val_start_date'],
            val_end_date=config['val_end_date'],
            test_start_date=config['test_start_date'],
            test_end_date=config['test_end_date'])
        run_model.plot_save_predictions(predictions, config['out_dir'], config['run_id'])
##---------------------------------------------------------------------------------------##
###########################################################################################
#RUN REPLICATES
###########################################################################################
# Run the configured number of replicates; each replicate label gets its own
# sub-directory of <out_dir>/<run_id>/ holding a parameter summary, a loss
# plot and a weights file.
# NOTE(review): the prepped data file is read in the run block but is not
# declared as an input (the declaration is commented out), so Snakemake will
# not build it for this rule on its own.
rule run_replicates:
    input:
        "03b_model/model_config.yaml",
        #os.path.join(config['out_dir'],'prepped_model_io_data')
    output:
        expand(os.path.join(config['out_dir'], config['run_id'], '{rep}/model_param_output.txt'),
               rep=replicates),
        expand(os.path.join(config['out_dir'], config['run_id'], '{rep}/losses.png'),
               rep=replicates),
        expand(os.path.join(config['out_dir'], config['run_id'], '{rep}/weights.pt'),
               rep=replicates)
    run:
        prepped_data_path = os.path.join(config['out_dir'], 'prepped_model_io_data')
        n_replicates = config['replicates']
        run_model.run_replicates(n_replicates, prepped_data_path)
##---------------------------------------------------------------------------------------##