-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
45 lines (35 loc) · 1.38 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import os
import random
import sys

import numpy as np
import torch

from data import *
from Backup import *
from solver import Solver
# Seed every RNG this script can draw from (Python's `random`, NumPy's
# legacy global generator, PyTorch CPU and all CUDA devices) so that runs
# are repeatable. `torch.cuda.manual_seed_all` is applied lazily, so this is
# safe on machines without a GPU.
# NOTE(review): `numpy` is now imported explicitly at the top of the file;
# previously `np` only resolved if `data`/`Backup` leaked it via `import *`.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
def main(args, model):
    """Build the training/validation loaders and train ``model`` with a Solver.

    Args:
        args: Parsed command-line options. This function reads ``json_dir``,
            ``batch_size``, ``num_workers``, ``lr`` and ``l2`` from it, and the
            whole object is also handed to ``Solver``.
        model: The network to train. It is printed, passed to ``summary``,
            moved to the GPU with ``.cuda()`` and optimised with Adam.

    Returns:
        None. Training is driven entirely by ``Solver.train()``.
    """
    # NOTE(review): cuDNN autotuning may select non-deterministic kernels,
    # which works against the fixed seeds set at module level -- confirm this
    # trade-off is intended.
    torch.backends.cudnn.benchmark = True

    # The datasets receive the real batch size while the loaders below use
    # batch_size=1 -- presumably each dataset item is already a full batch;
    # confirm against data.py.
    train_set = TrainDataset(json_dir=args.json_dir, batch_size=args.batch_size)
    valid_set = CvDataset(json_dir=args.json_dir, batch_size=args.batch_size)

    # Options shared by both loaders.
    loader_opts = dict(batch_size=1, num_workers=args.num_workers, pin_memory=True)
    loaders = {
        'tr_loader': TrainDataLoader(data_set=train_set, **loader_opts),
        'cv_loader': CvDataLoader(data_set=valid_set, **loader_opts),
    }

    print(model)
    # `summary` comes from a star import; per the original comment it reports
    # the network's parameter count.
    summary(model)
    model.cuda()

    adam = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
    Solver(loaders, model, adam, args).train()