-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy patheval.py
126 lines (109 loc) · 5.03 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import cv2
import argparse
import numpy as np
import torch
import torch.nn as nn
from utils.pyt_utils import ensure_dir, link_file, load_model, parse_devices
from utils.visualize import print_iou, show_img
from engine.evaluator import Evaluator
from engine.logger import get_logger
from utils.metric import hist_info, compute_score
from dataloader.RGBXDataset import RGBXDataset
from models.builder import EncoderDecoder as segmodel
from dataloader.dataloader import ValPre
from PIL import Image
# Module-level logger; SegEvaluator uses it to report each saved prediction image.
logger = get_logger()
class SegEvaluator(Evaluator):
    """Evaluator for RGB-X semantic segmentation.

    For every sample it runs sliding-window inference on the RGB image plus
    the extra modality, collects confusion-matrix statistics via
    ``hist_info``, and optionally saves and/or displays the prediction.
    ``compute_metric`` then folds the per-sample statistics into IoU and
    pixel-accuracy scores.
    """

    def func_per_iteration(self, data, device, config):
        """Run inference on one sample and return its metric statistics.

        Args:
            data: Sample dict with keys 'data' (RGB input), 'label' (ground
                truth), 'modal_x' (extra modality input) and 'fn' (file name
                stem used when saving outputs).
            device: Device passed through to ``sliding_eval_rgbX``.
            config: Config object; reads ``num_classes``, ``eval_crop_size``,
                ``eval_stride_rate`` and (when showing images) ``background``.

        Returns:
            dict with keys 'hist', 'labeled' and 'correct', as produced by
            ``hist_info``; consumed by ``compute_metric``.
        """
        img = data['data']
        label = data['label']
        modal_x = data['modal_x']
        name = data['fn']

        pred = self.sliding_eval_rgbX(img, modal_x, config.eval_crop_size,
                                      config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp}

        if self.save_path is not None:
            self._save_prediction_images(pred, name)

        if self.show_image:
            self._show_prediction_image(img, label, pred, config)

        return results_dict

    def _save_prediction_images(self, pred, name):
        """Write ``pred`` as a raw label PNG and as a palette-colored PNG.

        Raw label maps go to ``self.save_path`` and colored versions to
        ``self.save_path + '_color'``; both files are named ``name + '.png'``.
        """
        color_dir = self.save_path + '_color'
        ensure_dir(self.save_path)
        ensure_dir(color_dir)
        fn = name + '.png'

        # Convert once and reuse for both outputs. cv2.imwrite cannot encode
        # 64-bit integer arrays (e.g. an argmax result), and the palette image
        # needs 8-bit data anyway. Class ids are assumed to fit in 0..255,
        # which a 256-entry palette requires regardless.
        pred_u8 = pred.astype(np.uint8)

        # Colored result: attaching a palette to an 8-bit 'L' image turns it
        # into a 'P' image. This avoids the `mode` argument of
        # Image.fromarray, which newer Pillow releases deprecate. A palette
        # holds at most 256 RGB triplets, so pad/truncate to exactly 768 values.
        palette = np.asarray(self.dataset.get_class_colors()).ravel().tolist()
        palette = (palette + [0] * 768)[:768]
        result_img = Image.fromarray(pred_u8)
        result_img.putpalette(palette)
        result_img.save(os.path.join(color_dir, fn))

        # Raw result: pixel values are the predicted class ids.
        cv2.imwrite(os.path.join(self.save_path, fn), pred_u8)
        logger.info('Save the image ' + fn)

    def _show_prediction_image(self, img, label, pred, config):
        """Build a comparison image with ``show_img`` and display it.

        Blocks on ``cv2.waitKey(0)`` until a key is pressed.
        """
        colors = self.dataset.get_class_colors()
        clean = np.zeros(label.shape)
        comp_img = show_img(colors, config.background, img, clean, label, pred)
        cv2.imshow('comp_image', comp_img)
        cv2.waitKey(0)

    def compute_metric(self, results, config):
        """Aggregate per-sample statistics into IoU / accuracy metrics.

        Args:
            results: Iterable of dicts returned by ``func_per_iteration``.
            config: Config object; reads ``num_classes``.

        Returns:
            Tuple ``(result_line, mean_IoU)``: the text produced by
            ``print_iou`` and the mean IoU from ``compute_score``.
        """
        hist = np.zeros((config.num_classes, config.num_classes))
        correct = 0
        labeled = 0
        for d in results:
            hist += d['hist']
            correct += d['correct']
            labeled += d['labeled']

        iou, mean_IoU, _, freq_IoU, mean_pixel_acc, pixel_acc = compute_score(hist, correct, labeled)
        result_line = print_iou(iou, freq_IoU, mean_pixel_acc, pixel_acc,
                                self.dataset.class_names, show_no_back=False)
        return result_line, mean_IoU
if __name__ == "__main__":
    import importlib

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', '--epochs', default='last', type=str,
                        help="checkpoint epoch(s) to evaluate (default: 'last')")
    parser.add_argument('-d', '--devices', default='0', type=str,
                        help="device ids, parsed by parse_devices (default: '0')")
    parser.add_argument('-v', '--verbose', default=False, action='store_true',
                        help='verbose evaluation output')
    parser.add_argument('--show_image', '-s', default=False,
                        action='store_true',
                        help='display each prediction (waits for a key press)')
    parser.add_argument('--save_path', '-p', default=None,
                        help='directory in which to save predictions (optional)')
    parser.add_argument('--dataset_name', '-n', default='mfnet', type=str,
                        help='one of: mfnet, pst, nyu, sun (default: mfnet)')
    args = parser.parse_args()

    all_dev = parse_devices(args.devices)

    # Each supported dataset name maps to the module that defines its `config`.
    config_modules = {
        'mfnet': 'configs.config_MFNet',
        'pst': 'configs.config_pst900',
        'nyu': 'configs.config_nyu',
        'sun': 'configs.config_sunrgbd',
    }
    dataset_name = args.dataset_name
    if dataset_name not in config_modules:
        raise ValueError('Not a valid dataset name')
    config = importlib.import_module(config_modules[dataset_name]).config

    network = segmodel(cfg=config, criterion=None, norm_layer=nn.BatchNorm2d)

    # Dataset description handed to RGBXDataset. 'class_names' previously
    # appeared twice in this literal; a dict keeps only one, so it is listed once.
    data_setting = {'rgb_root': config.rgb_root_folder,
                    'rgb_format': config.rgb_format,
                    'gt_root': config.gt_root_folder,
                    'gt_format': config.gt_format,
                    'transform_gt': config.gt_transform,
                    'x_root': config.x_root_folder,
                    'x_format': config.x_format,
                    'x_single_channel': config.x_is_single_channel,
                    'class_names': config.class_names,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    val_pre = ValPre()
    dataset = RGBXDataset(data_setting, 'val', val_pre)

    # Inference only: disable autograd for the whole evaluation run.
    with torch.no_grad():
        segmentor = SegEvaluator(dataset, config.num_classes, config.norm_mean,
                                 config.norm_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.verbose, args.save_path,
                                 args.show_image, config)
        _, mean_IoU = segmentor.run_eval(config.checkpoint_dir, args.epochs, config.val_log_file,
                                         config.link_val_log_file)