-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
119 lines (104 loc) · 5.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vision_transformer import SwinUnet as ViT_seg
from trainer import trainer_synapse
from config import get_config
# Command-line interface. The parsed namespace is also passed to
# config.get_config() below, so Swin-config options (--cfg, --opts, --zip,
# --cache-mode, ...) are presumably consumed there.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='./data/REFUGE/', help='root dir for data')
parser.add_argument('--dataset', type=str,
                    default='REFUGE', help='dataset name: REFUGE, RITEyes or ISIC')
# NOTE: --list_dir and --num_classes are overridden per dataset in the
# __main__ block, so the values given here are effectively ignored.
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_REFUGE', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=3, help='output channel of network')  # originally 9
parser.add_argument('--output_dir', default='./save_train/Dice10_lr0.05', type=str,
                    help='output dir')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=150, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=32, help='batch_size per gpu')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.1,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
# A default config path was added locally; the upstream version declared
# --cfg with required=True instead.
parser.add_argument('--cfg', type=str,
                    default='./configs/swin_tiny_patch4_window7_224_lite.yaml',
                    metavar="FILE", help='path to config file')
parser.add_argument(
    "--opts",
    help="Modify config options by adding 'KEY VALUE' pairs. ",
    default=None,
    nargs='+',
)
parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                    help='no: no cache, '
                         'full: cache all data, '
                         'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
parser.add_argument('--resume', help='resume from checkpoint')
parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
parser.add_argument('--use-checkpoint', action='store_true',
                    help="whether to use gradient checkpointing to save memory")
parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
                    help='mixed precision opt level, if O0, no amp is used')
parser.add_argument('--tag', help='tag of experiment')
parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--throughput', action='store_true', help='Test throughput only')
parser.add_argument('--loss_type', type=str, default='origin', help='loss variant to train with')
parser.add_argument('--d_a', type=int, default=1, help='coefficient')
parser.add_argument('--d_b', type=int, default=0, help='exponent')
args = parser.parse_args()
# For these datasets the training data is read from the "train_npz"
# sub-directory of --root_path.
if args.dataset in ("REFUGE", "RITEyes", "ISIC"):
    args.root_path = os.path.join(args.root_path, "train_npz")
# Build the Swin config from the parsed CLI namespace (presumably reads
# --cfg / --opts and friends; see config.get_config).
config = get_config(args)
if __name__ == "__main__":
    # cuDNN: let it autotune kernels (faster, non-reproducible) unless
    # deterministic training was requested.
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True

    # Seed every RNG the pipeline may touch. manual_seed_all seeds all visible
    # GPUs (relevant when --n_gpu > 1), not only the current device.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Per-dataset settings. NOTE: these override any --list_dir and
    # --num_classes given on the command line.
    dataset_name = args.dataset
    dataset_config = {
        'REFUGE': {
            'list_dir': './lists/lists_REFUGE',
            'num_classes': 3,
        },
        'RITEyes': {
            'list_dir': './lists/lists_RITEyes',
            'num_classes': 3,
        },
        'ISIC': {
            'list_dir': './lists/lists_ISIC',
            'num_classes': 2,
        },
    }
    if dataset_name not in dataset_config:
        # Fail with a readable message instead of a bare KeyError below.
        raise ValueError(
            f"unsupported --dataset {dataset_name!r}; "
            f"expected one of: {', '.join(dataset_config)}"
        )

    # Linear LR scaling relative to a reference batch size of 24, applied
    # only when the batch size is a multiple of 6 (so the default of 32
    # leaves base_lr unchanged).
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.list_dir = dataset_config[dataset_name]['list_dir']

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)

    net = ViT_seg(config, img_size=args.img_size, num_classes=args.num_classes).cuda()
    # Presumably loads pretrained weights referenced by config; see
    # SwinUnet.load_from in networks.vision_transformer.
    net.load_from(config)

    # All supported datasets currently share the same training loop.
    trainer = {'REFUGE': trainer_synapse, 'RITEyes': trainer_synapse, 'ISIC': trainer_synapse}
    trainer[dataset_name](args, net, args.output_dir)