-
Notifications
You must be signed in to change notification settings - Fork 5
/
train.py
38 lines (28 loc) · 1.19 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from train_denoiser import *
import sys, getopt
import argparse
# Usage:
#   python3 train.py --epochs N --dataset PATH --in_nc C_IN --out_nc C_OUT \
#       --nc "64,128,256,512" --h NB --model_name OUT_NAME [--model_path CKPT]


def parse_channel_list(text):
    """Parse a list of integers separated by commas and/or whitespace.

    The ``--nc`` help text documents a comma-separated form
    (``"265340,268738,270774,270817"``). The previous implementation only
    split on whitespace, so that documented form raised ``ValueError``.
    Both forms, and mixtures such as ``"64, 128"``, are accepted here.

    Args:
        text: The raw ``--nc`` string.

    Returns:
        The parsed integers, in order.

    Raises:
        ValueError: If any token is not an integer literal.
    """
    # Normalise commas to spaces so that a single whitespace split handles
    # "a,b", "a b" and "a, b" alike; str.split() drops the empty tokens.
    return [int(token) for token in text.replace(",", " ").split()]


def build_parser():
    """Build the command-line parser for denoiser training."""
    parser = argparse.ArgumentParser(description='Denoiser Training')
    # Training objective
    parser.add_argument('--epochs', type=int,
                        help="Number of epochs")
    # Dataset
    parser.add_argument('--dataset', type=str, help='path to dataset of choice')
    parser.add_argument('--in_nc', type=int, help='input dimensions')
    parser.add_argument('--out_nc', type=int, help='output dimensions')
    # Model type
    parser.add_argument('--model_name', type=str, help="name of model")
    parser.add_argument('--model_path', type=str, help="path to model")
    # Setting
    parser.add_argument('--nc', type=str,
                        help='input should look like "265340,268738,270774,270817" '
                             '(comma- or space-separated integers)')
    # NOTE(review): this value is forwarded to denoiser(..., nb=...); confirm
    # that "dimensions of hidden state" is really what `nb` means there.
    parser.add_argument('--h', type=int, help='dimensions of hidden state')
    return parser


def main(argv=None):
    """Parse arguments, build the denoiser, train it and save the result.

    Args:
        argv: Argument list to parse. If None, ``sys.argv[1:]`` is used,
            which is argparse's default.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Without --nc the old code crashed with AttributeError on None.split();
    # report it as a normal usage error instead (exit status 2).
    if args.nc is None:
        parser.error('--nc is required, e.g. --nc "64,128,256,512"')
    print(args.nc)
    try:
        channels = parse_channel_list(args.nc)
    except ValueError:
        parser.error(f'--nc must be a list of integers, got {args.nc!r}')

    # `denoiser` comes from the project module `train_denoiser` (star-imported).
    de = denoiser(in_nc=args.in_nc, out_nc=args.out_nc, nc=channels, nb=args.h)
    if args.model_path:
        # Optionally start from an existing checkpoint.
        de.ld(args.model_path)
    # NOTE(review): --epochs, --dataset and --model_name are not required by the
    # parser; if omitted, None is passed through. Confirm the callees accept that.
    de.train_drunet(args.epochs, args.dataset)
    de.drunet.save(args.model_name)


if __name__ == "__main__":
    main()