forked from SJTUzhanglj/FCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
110 lines (101 loc) · 4.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
FCN
"""
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from dataset import SBDClassSeg, MyTestData
from transform import Colorize
from criterion import CrossEntropyLoss2d
from model import FCN8s
from myfunc import imsave
import visdom
import numpy as np
import argparse
import os
# Command-line options: which phase to run plus the data, weight and output
# paths. Each entry is (flag, default, help); every option is a plain string.
parser = argparse.ArgumentParser()
for _flag, _default, _help in (
    ('--phase', 'train', 'train or test'),
    ('--param', None, 'path to pre-trained parameters'),
    ('--data', './train', 'path to input data'),
    ('--out', './out', 'path to output data'),
):
    parser.add_argument(_flag, type=str, default=_default, help=_help)
# Drop the loop temporaries so the module namespace matches the plain version.
del _flag, _default, _help
opt = parser.parse_args()
print(opt)
# Visdom dashboard. The placeholder windows are created in the same 'fcn'
# environment the training loop draws into: visdom windows belong to one
# environment, so placeholders made in the default env would stay blank.
vis = visdom.Visdom()
win0 = vis.image(torch.zeros(3, 100, 100), env='fcn')  # later: loss curve
win1 = vis.image(torch.zeros(3, 100, 100), env='fcn')  # later: input image
win2 = vis.image(torch.zeros(3, 100, 100), env='fcn')  # later: prediction
win3 = vis.image(torch.zeros(3, 100, 100), env='fcn')  # later: target
# Turns a label map into a displayable colour image (presumably one colour
# per class -- see transform.Colorize).
color_transform = Colorize()
"""parameters"""
iterNum = 30
"""data loader"""
# dataRoot = '/media/xyz/Files/data/datasets'
# checkRoot = '/media/xyz/Files/fcn8s-deconv'
dataRoot = opt.data
if not os.path.exists(opt.out):
os.mkdir(opt.out)
if opt.phase == 'train':
checkRoot = opt.out
loader = torch.utils.data.DataLoader(
SBDClassSeg(dataRoot, split='train', transform=True),
batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
else:
outputRoot = opt.out
loader = torch.utils.data.DataLoader(
MyTestData(dataRoot, transform=True),
batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
"""nets"""
model = FCN8s()
if opt.param is None:
vgg16 = torchvision.models.vgg16(pretrained=True)
model.copy_params_from_vgg16(vgg16, copy_fc8=False, init_upscore=True)
else:
model.load_state_dict(torch.load(opt.param))
criterion = CrossEntropyLoss2d()
optimizer = torch.optim.Adam(model.parameters(), 0.0001, betas=(0.5, 0.999))
model = model.cuda()
if opt.phase == 'train':
    # ---- train ----
    model.train()
    for it in range(iterNum):
        epoch_loss = []  # loss of every step in the current epoch
        for ib, data in enumerate(loader):
            # Since PyTorch 0.4 tensors track gradients directly; Variable
            # wrappers are no-ops.
            inputs = data[0].cuda()
            targets = data[1].cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # loss is a 0-dim tensor: the old loss.data[0] raises IndexError
            # on PyTorch >= 0.4, .item() returns it as a Python float.
            epoch_loss.append(loss.item())
            loss.backward()
            optimizer.step()
            if ib % 2 == 0:
                # Add per-channel offsets back so the input displays with its
                # original intensities. NOTE(review): assumes the dataset
                # subtracted these means in this channel order -- confirm
                # against dataset.py.
                image = inputs[0].detach().cpu().clone()
                image[0] = image[0] + 122.67891434
                image[1] = image[1] + 116.66876762
                image[2] = image[2] + 104.00698793
                title = 'input (epoch: %d, step: %d)' % (it, ib)
                vis.image(image, win=win1, env='fcn', opts=dict(title=title))
                # Predicted label map: arg-max over dim 0 of the first
                # sample's output (presumably the class dimension).
                title = 'output (epoch: %d, step: %d)' % (it, ib)
                vis.image(color_transform(outputs[0].detach().cpu().max(0)[1]),
                          win=win2, env='fcn', opts=dict(title=title))
                title = 'target (epoch: %d, step: %d)' % (it, ib)
                vis.image(color_transform(targets.cpu()),
                          win=win3, env='fcn', opts=dict(title=title))
        if not epoch_loss:
            # Otherwise the average divides by zero and `ib` is unbound.
            raise RuntimeError('data loader yielded no batches; check --data')
        average = sum(epoch_loss) / len(epoch_loss)
        print('loss: %.4f, average: %.4f (epoch: %d, step: %d)'
              % (epoch_loss[-1], average, it, ib))
        # Plot this epoch's per-step losses. The average is reported above
        # rather than appended, so every plotted point is a real step loss.
        x = np.arange(1, len(epoch_loss) + 1, 1)
        title = 'loss (epoch: %d, step: %d)' % (it, ib)
        vis.line(np.array(epoch_loss), x, env='fcn', win=win0,
                 opts=dict(title=title))
        filename = '%s/FCN-epoch-%d-step-%d.pth' % (checkRoot, it, ib)
        torch.save(model.state_dict(), filename)
        print('save: (epoch: %d, step: %d)' % (it, ib))
else:
    # ---- test ----
    # eval() switches off train-time behaviour (e.g. dropout, if FCN8s uses
    # it) and no_grad() avoids building autograd graphs during inference.
    model.eval()
    with torch.no_grad():
        for ib, data in enumerate(loader):
            print('testing batch %d' % ib)
            inputs = data[0].cuda()
            outputs = model(inputs)
            colored = color_transform(outputs[0].cpu().max(0)[1])
            # data[1][0] is used as the output file's base name.
            imsave(os.path.join(outputRoot, data[1][0] + '.png'), colored)