forked from wkentaro/pytorch-fcn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
executable file
·94 lines (81 loc) · 3.09 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python
import argparse
import os
import os.path as osp
import fcn
import numpy as np
import skimage.io
import torch
from torch.autograd import Variable
import torchfcn
import tqdm
def main():
    """Evaluate a trained FCN checkpoint on the VOC2011 seg11valid split.

    Command line:
        model_file      path to the saved checkpoint (architecture is
                        inferred from the file-name prefix: fcn32s /
                        fcn16s / fcn8s / fcn8s-atonce)
        -g / --gpu      GPU id to expose via CUDA_VISIBLE_DEVICES (default 0)

    Prints overall/class accuracy, mean IU and FWAV accuracy, and writes a
    3x3 tile of sample segmentations to ``viz_evaluate.png``.

    Raises:
        ValueError: if the architecture cannot be inferred from the name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model_file', help='Model path')
    parser.add_argument('-g', '--gpu', type=int, default=0)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file

    root = osp.expanduser('~/data/datasets')
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False,
        num_workers=4, pin_memory=True)

    n_class = len(val_loader.dataset.class_names)

    # Infer the architecture from the checkpoint file-name convention.
    # Test the more specific 'fcn8s-atonce' prefix before plain 'fcn8s'.
    # Use n_class from the dataset instead of a hard-coded 21 so the two
    # stay consistent (for VOC they are the same value).
    basename = osp.basename(model_file)
    if basename.startswith('fcn32s'):
        model = torchfcn.models.FCN32s(n_class=n_class)
    elif basename.startswith('fcn16s'):
        model = torchfcn.models.FCN16s(n_class=n_class)
    elif basename.startswith('fcn8s-atonce'):
        model = torchfcn.models.FCN8sAtOnce(n_class=n_class)
    elif basename.startswith('fcn8s'):
        model = torchfcn.models.FCN8s(n_class=n_class)
    else:
        raise ValueError(
            'Cannot infer model architecture from file name: %s' % model_file)

    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
    # machines; load_state_dict copies onto the model's device afterwards.
    model_data = torch.load(model_file, map_location='cpu')
    try:
        # Checkpoint may be a bare state dict ...
        model.load_state_dict(model_data)
    except Exception:
        # ... or a training-script dict wrapping it under this key.
        model.load_state_dict(model_data['model_state_dict'])
    model.eval()

    print('==> Evaluating with VOC2011ClassSeg seg11valid')
    visualizations = []
    label_trues, label_preds = [], []
    # Variable(..., volatile=True) was removed in PyTorch >= 0.4;
    # torch.no_grad() is the supported way to disable autograd here.
    with torch.no_grad():
        for batch_idx, (data, target) in tqdm.tqdm(
                enumerate(val_loader), total=len(val_loader),
                ncols=80, leave=False):
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            score = model(data)

            imgs = data.cpu()
            # argmax over the class dimension -> per-pixel predicted labels
            lbl_pred = score.max(1)[1].cpu().numpy()
            lbl_true = target.cpu()
            for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                img, lt = val_loader.dataset.untransform(img, lt)
                label_trues.append(lt)
                label_preds.append(lp)
                # Keep at most 9 examples for the 3x3 visualization tile.
                if len(visualizations) < 9:
                    viz = fcn.utils.visualize_segmentation(
                        lbl_pred=lp, lbl_true=lt, img=img, n_class=n_class,
                        label_names=val_loader.dataset.class_names)
                    visualizations.append(viz)

    metrics = torchfcn.utils.label_accuracy_score(
        label_trues, label_preds, n_class=n_class)
    metrics = np.array(metrics)
    metrics *= 100  # report as percentages
    print('''\
Accuracy: {0}
Accuracy Class: {1}
Mean IU: {2}
FWAV Accuracy: {3}'''.format(*metrics))

    viz = fcn.utils.get_tile_image(visualizations)
    skimage.io.imsave('viz_evaluate.png', viz)


if __name__ == '__main__':
    main()