-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathmain.py
112 lines (98 loc) · 4.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
import os
from configs import get_best_config
from src.evaluation.evaluator import analyse_from_pkls
from src.evaluation.logger_config import init_logging
from src.evaluation.trainer import Trainer
from configs import datasets_config, thres_methods, get_thres_config
from src.evaluation.evaluation_utils import get_algo_class, get_dataset_class
def run_multi_seeds(out_dir_root, multi_seeds, ds_to_run, algos_to_run, test_run=False):
    """Train and analyse every (seed, dataset, algorithm) combination.

    For each combination the config returned by ``get_best_config`` is used and
    the run is delegated to ``train_analyse_algo``.

    Args:
        out_dir_root: Root output directory; each algorithm's results go to
            ``os.path.join(out_dir_root, algo_name)``.
        multi_seeds: Iterable of seeds, forwarded to ``train_analyse_algo``.
        ds_to_run: Iterable of dataset names.
        algos_to_run: Iterable of algorithm names.
        test_run: If True, and the config has a ``num_epochs`` key, it is
            forced to 1 so the run finishes quickly.
    """
    for seed in multi_seeds:
        for ds_name in ds_to_run:
            for algo_name in algos_to_run:
                # Shallow-copy so the test-run override below can never leak back
                # into a config object that get_best_config may share between calls.
                algo_config_dict = dict(get_best_config(algo_name=algo_name, ds_name=ds_name))
                if test_run and "num_epochs" in algo_config_dict:
                    algo_config_dict["num_epochs"] = 1
                # NOTE(review): this directory depends only on the algorithm, not on
                # the seed or dataset; presumably Trainer separates runs beneath it — confirm.
                out_dir_algo = os.path.join(out_dir_root, algo_name)
                train_analyse_algo(ds_name=ds_name, algo_name=algo_name, algo_config_dict=algo_config_dict,
                                   out_dir_algo=out_dir_algo, seed=seed)
def train_analyse_algo(ds_name, algo_name, algo_config_dict, out_dir_algo, seed):
    """Train one algorithm on one dataset with a single seed, then analyse the results.

    Steps: initialise logging under ``<out_dir_algo>/logs``, build a ``Trainer``,
    print a progress line, run ``train_predict``, and finally call
    ``analyse_from_pkls`` on ``out_dir_algo``.

    Args:
        ds_name: Dataset name; extra dataset kwargs come from ``datasets_config``
            when it has an entry for this name, otherwise none are passed.
        algo_name: Algorithm name, resolved to a class via ``get_algo_class``.
        algo_config_dict: Base algorithm configuration handed to the Trainer.
        out_dir_algo: Passed to the Trainer as ``output_dir`` and to
            ``analyse_from_pkls`` as ``results_root``.
        seed: Used both as the dataset seed and as the sole algorithm seed.
    """
    # NOTE(review): called once per run; if init_logging adds handlers without
    # removing earlier ones, log lines may be duplicated — verify.
    init_logging(os.path.join(out_dir_algo, 'logs'))
    logger = logging.getLogger(__name__)

    ds_kwargs = datasets_config[ds_name] if ds_name in datasets_config else {}

    dataset_cls = get_dataset_class(ds_name)
    algorithm_cls = get_algo_class(algo_name)
    trainer = Trainer(
        ds_class=dataset_cls,
        algo_seeds=[seed],
        algo_class=algorithm_cls,
        ds_seed=seed,
        ds_kwargs=ds_kwargs,
        algo_config_base=algo_config_dict,
        output_dir=out_dir_algo,
        logger=logger,
    )

    banner = "Training algo {} on dataset {} with config {} and seed {}".format(
        algo_name, ds_name, algo_config_dict, seed)
    print(banner)
    trainer.train_predict()

    # thres_config receives the get_thres_config callable itself, not its result;
    # presumably analyse_from_pkls invokes it per threshold method — confirm.
    analyse_from_pkls(
        results_root=out_dir_algo,
        thres_methods=thres_methods,
        eval_root_cause=True,
        point_adjust=False,
        eval_dyn=True,
        eval_R_model=True,
        thres_config=get_thres_config,
        telem_only=True,
        composite_best_f1=True,
    )
# Every algorithm exercised by the runners below. Kept in a single place so the
# full benchmark and the all-algorithm smoke trial cannot drift apart.
ALL_ALGOS = (
    "RawSignalBaseline",
    "PcaRecons",
    "UnivarAutoEncoder_recon_all",
    "AutoEncoder_recon_all",
    "LSTM-ED_recon_all",
    "TcnED",
    "VAE-LSTM",
    "MSCRED",
    "OmniAnoAlgo",
)


def run_all_benchmarks(out_dir_root):
    """Run the full benchmark.

    Every algorithm in ``ALL_ALGOS`` runs on swat, damadics-s, wadi, msl, smap,
    smd and skab for seeds 0-4. The best configs are used unmodified
    (``test_run=False``).

    Args:
        out_dir_root: Root directory under which per-algorithm results are written.
    """
    run_multi_seeds(out_dir_root=out_dir_root,
                    multi_seeds=[0, 1, 2, 3, 4],
                    ds_to_run=["swat", "damadics-s", "wadi", "msl", "smap", "smd", "skab"],
                    algos_to_run=list(ALL_ALGOS),
                    test_run=False)


def run_quick_trial_5_ds(out_dir_root):
    """Smoke-test RawSignalBaseline on five datasets.

    Uses seed 0 in test-run mode, so ``num_epochs`` is forced to 1 where the
    config has that key.

    Args:
        out_dir_root: Root directory under which per-algorithm results are written.
    """
    run_multi_seeds(out_dir_root=out_dir_root,
                    multi_seeds=[0],
                    ds_to_run=["damadics-s", "msl", "smap", "smd", "skab"],
                    algos_to_run=["RawSignalBaseline"],
                    test_run=True)


def run_quick_trial_all_algos(out_dir_root):
    """Smoke-test every algorithm in ``ALL_ALGOS`` on the skab dataset.

    Uses seed 0 in test-run mode, so ``num_epochs`` is forced to 1 where the
    config has that key.

    Args:
        out_dir_root: Root directory under which per-algorithm results are written.
    """
    run_multi_seeds(out_dir_root=out_dir_root,
                    multi_seeds=[0],
                    ds_to_run=["skab"],
                    algos_to_run=list(ALL_ALGOS),
                    test_run=True)
if __name__ == "__main__":
    import argparse

    # Map each CLI choice to the runner it launches; this replaces selecting an
    # experiment by (un)commenting calls in the source.
    experiments = {
        "quick_trial_all_algos": run_quick_trial_all_algos,
        "quick_trial_5_ds": run_quick_trial_5_ds,
        "all_benchmarks": run_all_benchmarks,
    }
    parser = argparse.ArgumentParser(
        description="Train algorithms on datasets and analyse their results.")
    parser.add_argument(
        "--experiment",
        choices=sorted(experiments),
        default="quick_trial_all_algos",
        help="Which experiment runner to launch (default: %(default)s).")
    parser.add_argument(
        "--out-dir",
        default=os.path.join(os.getcwd(), "reports", "trial"),
        help="Root output directory (default: reports/trial under the current working directory).")
    args = parser.parse_args()
    experiments[args.experiment](args.out_dir)