run_demo.py
#!/usr/bin/env python3
import torch
import numpy as np
from smart_home_dataset import SmartHomeDataset
from classifier import Classifier
from torch import optim
import utils
import callbacks as cb
import time
from generative_replay_learner import GenerativeReplayLearner
import arg_params
import json
import os
import copy
import torch.multiprocessing as mp
from run_main import *


def get_g_iter(method, cmd=None):
    # Number of generator training iterations: cmd is scaled to thousands.
    return 1000 * int(cmd)


def get_hidden_unit(args):
    # Fully-connected hidden-layer size, chosen per dataset.
    if args.data_dir == "pamap":
        return 500
    elif args.data_dir == "dsads":
        return 1000
    elif args.data_dir == "housea":
        return 100
    else:
        return 100


if __name__ == "__main__":
    parser = arg_params.get_parser()
    args = parser.parse_args()
    print("Arguments")
    print(args)

    result_folder = args.results_dir
    print("\n")
    print("STEP1: load datasets")
    base_dataset = select_dataset(args)

    # Methods to compare: (name, cmd) pairs; cmd feeds get_g_iter below.
    methods = [
        ("none", 0), ("mp-gan", 0), ("sg-cgan", 0)
    ]

    # Template record describing one evaluation result.
    identity = {
        "task_order": None,
        "method": None,
        "train_session": None,
        "task_index": None,
        "no_of_test": None,
        "no_of_correct_prediction": None,
        "accuracy": None,
        "solver_training_time": None,
        "generator_training_time": None,
    }
    identity["task_order"] = 1

    # Split into train/test, optionally oversample, then split into tasks.
    traindata, testdata = base_dataset.train_test_split()
    dataset = traindata
    if args.oversampling:
        dataset = traindata.resampling()
    train_datasets, config, classes_per_task = dataset.split(tasks=args.tasks)
    test_datasets, _, _ = testdata.split(tasks=args.tasks)

    base_args = args
    for method in methods:
        start = time.time()
        m, cmd = method
        identity["method"] = m

        # Work on a fresh copy of the arguments for each method.
        args = copy.deepcopy(base_args)
        args.critic_fc_units = get_hidden_unit(args)
        args.generator_fc_units = get_hidden_unit(args)
        args.g_iters = get_g_iter(m, cmd + 1)

        env_name = "Continual learning [" + m + "]"
        visdom = {'env': env_name, 'graph': "models"}

        run_model(identity, method, args, config, train_datasets, test_datasets,
                  verbose=True, visdom=visdom)

        training_time = time.time() - start
        print("Training Time", training_time)