-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrun_replay_important_v2.py
127 lines (93 loc) · 3.48 KB
/
run_replay_important_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
import torch
import numpy as np
from smart_home_dataset import SmartHomeDataset
from classifier import Classifier
from torch import optim
import utils
import callbacks as cb
import time
from generative_replay_learner import GenerativeReplayLearner;
import arg_params
import json
import os
import copy
import torch.multiprocessing as mp
from run_main import *
def select_hidden_unit(args):
    """Choose the hidden-layer width for the dataset named by ``args.data_dir``.

    "pamap" and "dsads" get 1000 units, "housea" gets 200, and any other
    dataset name falls back to 500.

    Side effect: the chosen width is written to ``args.hidden_units``.

    Returns:
        The value now stored in ``args.hidden_units``.
    """
    dataset_name = args.data_dir
    if dataset_name == "pamap" or dataset_name == "dsads":
        width = 1000
    elif dataset_name == "housea":
        width = 200
    else:
        width = 500
    args.hidden_units = width
    return args.hidden_units
if __name__ == "__main__":
    # Script entry point. For each of `ntask` task orderings, split the data
    # once and run every (method, cmd) configuration on that split. Finally,
    # print the total wall-clock time in seconds.
    parser = arg_params.get_parser()
    args = parser.parse_args()
    print("Arguments")
    print(args)
    result_folder = args.results_dir
    print("\n")
    print("STEP1: load datasets")
    base_dataset = select_dataset(args)

    # (method name, cmd) pairs, in run order: for each cmd in 0..4, "offline"
    # then "sg-cgan". cmd becomes run_args.rnt = cmd * 0.25, i.e.
    # 0.0, 0.25, 0.5, 0.75, 1.0.
    # Variants listed here previously but disabled: "mp-gan", "mp-wgan",
    # "sg-cwgan".
    methods = [(m, cmd) for cmd in range(5) for m in ("offline", "sg-cgan")]

    start = time.time()
    ntask = 10
    # The parsed namespace stays untouched. Every run works on its own deep
    # copy, so per-run overrides never leak into later runs or task orders.
    base_args = args

    tasks = []
    if base_args.task_order is not None:
        # One ordering per line, entries separated by ';'.
        with open(base_args.task_order) as ft:
            tasks = [line.strip().split(";") for line in ft]
        # Fail fast instead of hitting an IndexError after several full runs.
        if len(tasks) < ntask:
            parser.error(
                "task order file {!r} has {} line(s), but {} are required".format(
                    base_args.task_order, len(tasks), ntask))

    for task_order in range(ntask):
        if base_args.task_order is not None:
            base_dataset.permu_task_order(tasks[task_order])
        else:
            # No explicit order was given: let the dataset pick one
            # (presumably random — verify permu_task_order) and save it.
            base_dataset.permu_task_order()
            save_order(result_folder, task_order, base_dataset.classes)

        # Per-ordering record passed to run_model. Only "task_order" and
        # "method" are set here; the other None fields are presumably filled
        # in by run_model (confirm in run_main).
        identity = {
            "task_order": task_order,
            "method": None,
            "train_session": None,
            "task_index": None,
            "no_of_test": None,
            "no_of_correct_prediction": None,
            "accuracy": None,
            "solver_training_time": None,
            "generator_training_time": None,
        }

        traindata, testdata = base_dataset.train_test_split()
        dataset = traindata.resampling() if base_args.oversampling else traindata
        train_datasets, config, classes_per_task = dataset.split(tasks=base_args.tasks)
        test_datasets, _, _ = testdata.split(tasks=base_args.tasks)
        print("******* Run ", task_order, "*******")
        print("\n")

        for method in methods:
            m, cmd = method
            identity["method"] = m
            run_args = copy.deepcopy(base_args)
            run_args.rnt = cmd * 0.25
            # select_hidden_unit also stores the width on run_args.hidden_units.
            hidden_units = select_hidden_unit(run_args)
            run_args.critic_fc_units = hidden_units
            run_args.generator_fc_units = hidden_units
            run_args.g_iters = get_g_iter(m, None)
            # Runs execute sequentially. A multiprocessing.Pool variant
            # (pool.apply_async(run_model, ...)) was sketched here and is
            # disabled.
            run_model(identity, method, run_args, config, train_datasets,
                      test_datasets, True)

    training_time = time.time() - start
    print(training_time)
    # clearup_tmp_file(result_folder, ntask, methods)