forked from tajo/deeplearning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
81 lines (73 loc) · 1.94 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
'''
@author Vojtech Miksu <[email protected]>
This is the main script, it runs training, validating and testing for all datasets based on
preferred parameters.
'''
import datetime
import time
import csv
import os
import json
import sys
from libs import SdA
def run(tasks):
    '''
    Iterate over the tasks from the json config and process each one in order.

    :param tasks: iterable of per-task parameter dicts (see run_task)
    '''
    # enumerate replaces the hand-maintained task counter of the original
    for task_num, task_params in enumerate(tasks, start=1):
        # single-argument print(...) is valid on both Python 2 and Python 3
        print('##################### PROCESSING TASK #{} #####################'.format(task_num))
        run_task(task_params)
def run_task(params):
    '''
    Prepare and run a single task: make sure the CSV logfile starts with a
    header row, then train/validate/test an SdA model for every dataset
    listed in the task config.

    :param params: dict of task parameters loaded from the json config;
                   must contain 'logfile' and 'datasets', plus the SdA
                   hyper-parameters passed through below. 'targets' is
                   optional and defaults to 3.
    '''
    logfile = params['logfile']
    # Write the CSV header only when the logfile does not exist yet or is
    # empty, so repeated runs keep appending results under a single header.
    if not os.path.isfile(logfile) or os.stat(logfile).st_size == 0:
        # text mode: csv rows are text; the original 'wb' breaks on Python 3
        with open(logfile, 'w') as myfile:
            header = ['date_time',
                      'dataset',
                      'target_name',
                      'finetune_lr',
                      'pretraining_epochs',
                      'pretrain_lr',
                      'training_epochs',
                      'batch_size',
                      'n_ins',
                      'n_outs',
                      'hidden_layers_sizes',
                      'corruption_levels (%)',
                      'valid_perf (%)',
                      'test_perf (%)',
                      'test_recall',
                      'run_time (min)']
            writer = csv.writer(myfile, delimiter=',')
            writer.writerow(header)
    # number of target columns in dataset, so SdA can split them from the rest;
    # .get() replaces the original bare except that swallowed every error
    targets = params.get('targets', 3)  # default for stock datasets
    for dataset in params['datasets']:
        SdA.test_SdA(params['finetune_lr'],
                     params['target_name'],
                     params['pretraining_epochs'],
                     params['pretrain_lr'],
                     params['training_epochs'],
                     dataset,
                     params['batch_size'],
                     params['n_ins'],
                     params['hidden_layers_sizes'],
                     params['n_outs'],
                     params['corruption_levels'],
                     logfile,
                     targets)
if __name__ == '__main__':
    # Expects the path to a json task config as the first command line argument.
    try:
        json_data = open(sys.argv[1])
    except (IndexError, IOError):
        # IndexError: no argument given; IOError: file missing/unreadable
        # (narrowed from the original bare except, which hid real errors)
        print('You have to specify some valid "task.json" as the input argument')
        sys.exit(1)  # non-zero status signals failure to the caller
    try:
        data = json.load(json_data)
    finally:
        json_data.close()  # the original never closed the config file
    run(data['tasks'])