-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
188 lines (172 loc) · 6.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/env python
"""
train ALT
Created by anonymous on 2021-11-21
"""
import os
import sys
sys.path.append(os.path.abspath(''))
import random
import argparse
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import dataloader
from lib.datasets import get_dataset, pacs
from lib.datasets.transforms import GreyToColor, IdentityTransform, ToGrayScale, LaplacianOfGaussianFiltering
import torchvision.transforms as transforms
from trainer_alt import *
from lib.networks import get_network
from metann import Learner
def _seed_everything(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.rand_seed``.

    Does nothing besides printing when ``args.rand_seed`` is None. When a seed
    is given, cuDNN is also switched to deterministic mode with benchmarking
    disabled.
    """
    # Imported locally so this block does not rely on `np` leaking in through
    # the star import from trainer_alt.
    import numpy as np

    print("Random Seed: ", args.rand_seed)
    if args.rand_seed is None:
        return

    random.seed(args.rand_seed)
    np.random.seed(args.rand_seed)
    torch.manual_seed(args.rand_seed)
    print(args.gpu_ids, type(args.gpu_ids))
    # The previous check was `len(args.gpu_ids) >= 0`, which is true for every
    # list; only seed CUDA explicitly when at least one GPU id was given.
    if isinstance(args.gpu_ids, list) and args.gpu_ids:
        torch.cuda.manual_seed_all(args.rand_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _dataset_config(args):
    """Validate the dataset-related arguments and return their settings.

    Returns:
        A ``(data_dir, domains, stats)`` tuple, where ``stats`` is the
        ``(mean, std)`` pair handed to ``transforms.Normalize``.

    Raises:
        ValueError: if ``args.data_name`` is unknown, ``args.n_classes`` does
            not match the dataset, or ``args.source`` is not one of the
            dataset's domains.
    """
    half = (0.5, 0.5, 0.5)
    # data_name -> (expected n_classes, data_dir, domains, (mean, std))
    known = {
        'pacs': (
            7, './data/PACS/',
            ['photo', 'art_painting', 'cartoon', 'sketch'],
            (half, half),
        ),
        'digits': (
            10, "./data/",
            ['mnist10k', 'mnist_m', 'svhn', 'usps', 'synth'],
            (half, half),
        ),
        'officehome': (
            65, './data/OfficeHome/',
            ['real', 'art', 'clipart', 'product'],
            ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ),
    }
    if args.data_name not in known:
        raise ValueError(
            'unknown data_name {!r}; allowed data_name {}'.format(
                args.data_name, sorted(known)))

    n_classes, data_dir, domains, stats = known[args.data_name]
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if args.n_classes != n_classes:
        raise ValueError(
            '{} expects n_classes={}, got {}'.format(
                args.data_name, n_classes, args.n_classes))
    if args.source not in domains:
        raise ValueError(
            'unknown source {!r}; allowed source domains {}'.format(
                args.source, domains))
    return data_dir, domains, stats


def _build_transforms(args, stats):
    """Build the ``(train_transform, test_transform)`` pipelines."""
    is_digits = args.data_name == 'digits'

    train_steps = [
        transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1)),
    ]
    if args.colorjitter:
        # The same strength is used for all four ColorJitter arguments.
        train_steps.append(transforms.ColorJitter(*[args.colorjitter] * 4))
    if not is_digits:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps.append(transforms.ToTensor())
    if is_digits:
        train_steps.append(GreyToColor())
    train_steps.append(transforms.Normalize(*stats))

    test_steps = [
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        GreyToColor() if is_digits else IdentityTransform(),
        transforms.Normalize(*stats),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def _build_datasets(args, data_dir, domains, trg_domains,
                    train_transform, test_transform):
    """Create the training set and the per-domain validation/test sets.

    Returns:
        A ``(trainset, validsets, testsets)`` tuple; the latter two are dicts
        keyed by domain name.
    """
    if args.data_name == 'pacs':
        trainset = pacs.PACS(
            root=data_dir, domain=args.source,
            split='train', transform=train_transform
        )
        validsets = {
            dd: pacs.PACS(
                root=data_dir, domain=dd,
                split='crossval', transform=test_transform
            ) for dd in domains
        }
        testsets = {
            dd: pacs.PACS(
                root=data_dir, domain=dd,
                split='test', transform=test_transform
            ) for dd in trg_domains
        }
        # add the source crossval as a test too
        testsets[args.source] = validsets[args.source]
    elif args.data_name == 'officehome':
        sourceset = ImageFolder(
            os.path.join(data_dir, args.source),
            transform=train_transform
        )
        n_train = int(0.9 * len(sourceset))
        # NOTE(review): src_validset is a split of `sourceset`, so it is read
        # through train_transform (random crops/flips) rather than
        # test_transform -- confirm this is intended.
        trainset, src_validset = torch.utils.data.random_split(
            sourceset, [n_train, len(sourceset) - n_train],
            generator=torch.Generator().manual_seed(381)
        )
        validsets = {}
        testsets = {}
        for dd in trg_domains:
            dd_set = ImageFolder(
                os.path.join(data_dir, dd), transform=test_transform)
            n_valid = int(0.1 * len(dd_set))
            # A fresh generator with a fixed seed per split keeps the
            # valid/test partition identical across runs.
            validsets[dd], testsets[dd] = torch.utils.data.random_split(
                dd_set, [n_valid, len(dd_set) - n_valid],
                generator=torch.Generator().manual_seed(381)
            )
        validsets[args.source] = src_validset
        testsets[args.source] = src_validset
    else:  # 'digits' -- data_name was already validated by _dataset_config
        trainset = get_dataset(
            args.source, root=data_dir, train=True, download=True,
            transform=train_transform
        )
        validsets = {
            domain: get_dataset(
                domain, root=data_dir, train=False, download=True,
                transform=test_transform
            ) for domain in domains
        }
        # The digits benchmark reuses the same splits for validation and test.
        testsets = validsets
    return trainset, validsets, testsets


def _eval_loaders(datasets, batch_size):
    """Wrap each dataset of a ``{domain: dataset}`` dict in an unshuffled loader."""
    return {
        name: torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=2
        ) for name, dataset in datasets.items()
    }


def main(args):
    """Train ALT on one source domain and evaluate on the other domains.

    Seeds the RNGs, builds the transforms, datasets and data loaders for
    ``args.data_name`` ('pacs', 'digits' or 'officehome'), wraps the network
    from ``get_network`` in a ``Learner`` and hands everything to
    ``ALT(args).train``.

    Args:
        args: parsed command-line namespace. This function reads rand_seed,
            gpu_ids, data_name, n_classes, source, image_size, colorjitter,
            batch_size, net and drop; the whole namespace is also passed on
            to ``ALT``.

    Raises:
        ValueError: on an unknown ``data_name``, a mismatching ``n_classes``
            or a ``source`` that is not a domain of the chosen dataset.
    """
    # GPU and random seed
    _seed_everything(args)
    # NOTE(review): the nesting of this call relative to the seed check could
    # not be recovered from the original layout; it is applied unconditionally
    # here -- confirm.
    torch.set_num_threads(1)

    # DATALOADERS
    data_dir, domains, stats = _dataset_config(args)
    trg_domains = [dd for dd in domains if dd != args.source]
    print(args.data_name)
    # Report the actual target domains (previously the full domain list,
    # source included, was printed under the "TRG" label).
    print("SRC:{}; TRG:{}".format(args.source, trg_domains))

    # transforms
    train_transform, test_transform = _build_transforms(args, stats)

    ## datasets
    print("\n=========Preparing Data=========")
    trainset, validsets, testsets = _build_datasets(
        args, data_dir, domains, trg_domains, train_transform, test_transform
    )

    trainloaders = [
        torch.utils.data.DataLoader(
            trainset, batch_size=args.batch_size, shuffle=True,
            num_workers=8
        )
    ]
    validloaders = _eval_loaders(validsets, args.batch_size)
    testloaders = _eval_loaders(testsets, args.batch_size)

    # MODEL
    print("\n=========Building Model=========")
    net = Learner(
        get_network(
            args.net, num_classes=args.n_classes, pretrained=True,
            drop=args.drop
        )
    )

    trainer = ALT(args)
    # NOTE(review): data_mean/data_std are fixed at 0.5 although the
    # officehome branch normalizes with ImageNet statistics; confirm whether
    # the trainer expects the statistics actually used in Normalize.
    trainer.train(
        net,
        trainset,
        trainloaders, validloaders, testloaders=testloaders,
        data_mean=(0.5, 0.5, 0.5), data_std=((0.5, 0.5, 0.5))
    )
if __name__ == '__main__':
    # Command-line entry point: build the parser, parse argv, run training.
    # The description used to read 'PyTorch CIFAR10 Training', but this script
    # only handles the pacs / digits / officehome settings of ALT.
    parser = argparse.ArgumentParser(
        description='PyTorch ALT Training (PACS / Digits / OfficeHome)'
    )
    # add_basic_args is not defined in this file; presumably it comes from the
    # star import of trainer_alt and registers the options main() reads.
    add_basic_args(parser)
    args = parser.parse_args()
    main(args)