-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.py
122 lines (88 loc) · 4.05 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import logging
import pickle
import json
import torch
import os.path as osp
import warnings
warnings.filterwarnings('ignore')
from methods import ProDesign
from API import Recorder
from utils import *
class Exp:
def __init__(self, args, show_params=True):
self.args = args
self.config = args.__dict__
self.device = self._acquire_device()
self.total_step = 0
self._preparation()
if show_params:
print_log(output_namespace(self.args))
def _acquire_device(self):
if self.args.use_gpu:
device = torch.device('cuda:0')
print('Use GPU:',device)
else:
device = torch.device('cpu')
print('Use CPU')
return device
def _preparation(self):
set_seed(self.args.seed)
# log and checkpoint
self.path = osp.join(self.args.res_dir, self.args.ex_name)
check_dir(self.path)
self.checkpoints_path = osp.join(self.path, 'checkpoints')
check_dir(self.checkpoints_path)
sv_param = osp.join(self.path, 'model_param.json')
with open(sv_param, 'w') as file_obj:
json.dump(self.args.__dict__, file_obj)
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO, filename=osp.join(self.path, 'log.log'),
filemode='a', format='%(asctime)s - %(message)s')
self._get_data()
self._build_method()
def _build_method(self):
steps_per_epoch = 1000
if self.args.method == 'ProDesign':
self.method = ProDesign(self.args, self.device, steps_per_epoch)
def _get_data(self):
self.train_loader, self.valid_loader, self.test_loader = get_dataset(self.config)
def train(self):
recorder = Recorder(self.args.patience, verbose=True)
for epoch in range(self.args.epoch):
train_loss, train_perplexity = self.method.train_one_epoch(self.train_loader)
if epoch % self.args.log_step == 0:
with torch.no_grad():
valid_loss, valid_perplexity = self.valid()
# self._save(name=str(epoch))
self.test()
print_log('Epoch: {0}, Steps: {1} | Train Loss: {2:.4f} Train Perp: {3:.4f} Valid Loss: {4:.4f} Valid Perp: {5:.4f}\n'.format(epoch + 1, len(self.train_loader), train_loss, train_perplexity, valid_loss, valid_perplexity))
recorder(valid_loss, self.method.model, self.path)
if recorder.early_stop:
print("Early stopping")
logging.info("Early stopping")
break
best_model_path = osp.join(self.path, 'checkpoint.pth')
self.method.model.load_state_dict(torch.load(best_model_path))
def valid(self):
valid_loss, valid_perplexity = self.method.valid_one_epoch(self.valid_loader)
print_log('Valid Perp: {0:.4f}'.format(valid_perplexity))
return valid_loss, valid_perplexity
def test(self):
test_perplexity, test_recovery, test_subcat_recovery = self.method.test_one_epoch(self.test_loader)
print_log('Test Perp: {0:.4f}, Test Rec: {1:.4f}\n'.format(test_perplexity, test_recovery))
for cat in test_subcat_recovery.keys():
print_log('Category {0} Rec: {1:.4f}\n'.format(cat, test_subcat_recovery[cat]))
return test_perplexity, test_recovery
if __name__ == '__main__':
from parser import create_parser
args = create_parser()
config = args.__dict__
print(config)
exp = Exp(args)
# svpath = '/gaozhangyang/experiments/ProDesign/results/ProDesign/'
# exp.method.model.load_state_dict(torch.load(svpath+'checkpoint.pth'))
print('>>>>>>>>>>>>>>>>>>>>>>>>>> training <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
exp.train()
print('>>>>>>>>>>>>>>>>>>>>>>>>>> testing <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
test_perp, test_rec = exp.test()