-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_scanpath.py
208 lines (168 loc) · 8.55 KB
/
plot_scanpath.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from os import makedirs, path, listdir
from random import randint
from skimage import io, transform
import json
import argparse
import sys
""" Usage:
To plot a model's scanpath on a given image:
plot_scanpath.py -dataset <dataset_name> -img <image_name> -model <model_name>
    To plot a human subject's scanpath on a given image (a random subject who found the target is used when no ID is given):
        plot_scanpath.py -dataset <dataset_name> -img <image_name> -human [<subject_id>]
"""
""" The main method of this script (plot_scanpath) belongs to https://github.com/cvlab-stonybrook/Scanpath_Prediction/plot_scanpath.py """
DATASETS_DIR = '../Datasets'
RESULTS_DIR = '../Results'
def plot_scanpath(img, xs, ys, fixation_size, bbox, title, save_path):
    """Draw a scanpath on top of an image, save it as a PNG and display it.

    Consecutive fixations are joined by arrows, every fixation is drawn as a
    numbered circle (the first one filled with a different colour) and the
    target's bounding box is outlined.

    Args:
        img: image passed to ``imshow`` (a gray colormap is requested).
        xs: x coordinates of the fixations.
        ys: y coordinates of the fixations, indexed in parallel with ``xs``.
        fixation_size: pair whose second element gives the circle diameter.
        bbox: target box, used as ``[x, y, width, height]``.
        title: figure title; also the name of the output file (plus '.png').
        save_path: directory where the PNG is written.
    """
    first_fixation_color = 'red'
    path_color = 'yellow'

    _, axes = plt.subplots()
    axes.imshow(img, cmap=plt.cm.gray)

    # Arrows are drawn first so that the fixation circles end up on top of them
    for step in range(1, len(xs)):
        start_x, start_y = xs[step - 1], ys[step - 1]
        plt.arrow(start_x, start_y, xs[step] - start_x, ys[step] - start_y,
                  width=3, color=path_color, alpha=0.5)

    for index, x in enumerate(xs):
        y = ys[index]
        fill_color = first_fixation_color if index == 0 else path_color
        marker = plt.Circle((x, y),
                            radius=fixation_size[1] // 2,
                            edgecolor='red',
                            facecolor=fill_color,
                            alpha=0.5)
        axes.add_patch(marker)
        # Fixations are numbered starting from one
        plt.annotate(str(index + 1), xy=(x, y + 3), fontsize=10, ha="center", va="center")

    # Draw target's bbox
    target_outline = Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                               alpha=0.7, edgecolor='red', facecolor='none', linewidth=2)
    axes.add_patch(target_outline)

    # To draw grid, useful for plotting nnIBS's scanpaths
    # box_size = 32
    # box_x = 0
    # box_y = 0
    # rows = round(img.shape[0] / box_size)
    # columns = round(img.shape[1] / box_size)
    # for row in range(rows):
    #     box_y = box_size * row
    #     for column in range(columns):
    #         box_x = box_size * column
    #         rect = Rectangle((box_x, box_y), box_size, box_size, alpha=0.5, edgecolor='yellow', facecolor='none', linewidth=2)
    #         ax.add_patch(rect)

    axes.axis('off')
    axes.set_title(title)

    output_file = path.join(save_path, title + '.png')
    plt.savefig(output_file)
    plt.show()
    plt.close()
def parse_args():
    """Parse the command-line options of the script.

    Exactly one of ``-model`` / ``-human`` has to be given, together with
    ``-dataset`` and ``-img``. They are declared as required so that argparse
    reports a usage error right away, instead of the script failing later
    while building paths from a ``None`` value.

    Returns:
        argparse.Namespace with the attributes:
            model: name of the model, or None when ``-human`` is used.
            human: False when absent, True when given without a value,
                otherwise the subject ID as a string.
            dataset: name of the dataset.
            img: name of the image (or 'notfound').
    """
    parser = argparse.ArgumentParser()

    # The scanpath comes either from a model or from a human subject, never both
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('-model', type=str, help='Name of the visual search model')
    group.add_argument('-human', nargs='?', const=True, default=False,
                       help='ID of human subject to plot; leave blank to plot a scanpath generated by a random subject (who has found the target)')

    parser.add_argument('-dataset', type=str, required=True, help='Name of the dataset')
    parser.add_argument('-img', type=str, required=True,
                        help='Name of the image on which to draw the scanpath (write \'notfound\' to plot target not found images)')

    return parser.parse_args()
def get_trial_info(image_name, trials_properties):
    """Return the first trial whose 'image' entry equals ``image_name``.

    Args:
        image_name: name of the image to look for.
        trials_properties: iterable of trial dictionaries, each with an
            'image' key.

    Raises:
        NameError: if no trial refers to ``image_name``.
    """
    matching_trials = (entry for entry in trials_properties if entry['image'] == image_name)
    found = next(matching_trials, None)
    if found is None:
        raise NameError('Image name must be in the dataset')

    return found
def rescale_coordinate(value, old_size, new_size, fixation_size=None, is_grid=False):
    """Map a coordinate to a new scale.

    Args:
        value: coordinate to convert (a cell index when ``is_grid`` is set).
        old_size: extent of the axis ``value`` is expressed in.
        new_size: extent of the axis to convert to.
        fixation_size: size of a grid cell; only used when ``is_grid`` is set.
        is_grid: whether ``value`` is a cell index in a grid.

    Returns:
        For grids, the center of cell ``value``; otherwise ``value`` scaled
        proportionally from ``old_size`` to ``new_size``.
    """
    if not is_grid:
        return (value / old_size) * new_size

    # Rescale fixation to center of the cell in the grid
    half_cell = fixation_size // 2
    return value * fixation_size + half_cell
def load_dict_from_json(json_file_path):
    """Load the contents of a JSON file.

    Args:
        json_file_path: path of the JSON file to read.

    Returns:
        The decoded JSON content, or an empty dictionary when the path does
        not exist.
    """
    if path.exists(json_file_path):
        with open(json_file_path) as source:
            return json.load(source)

    return {}
def process_image(img_scanpath, subject, image_name, dataset_name, trial_info, images_path):
    """Load an image, bring its scanpath to the image's scale and plot it.

    The plot is written under 'Plots/<dataset_name>_dataset/<image stem>/'.

    Args:
        img_scanpath: scanpath dictionary; the keys read are 'receptive_height',
            'receptive_width', 'image_height', 'image_width', 'X', 'Y' and
            'target_bbox' (used as [top, left, bottom, right]).
        subject: name of the model or human subject; when it contains 'IBS'
            the scanpath is treated as grid cell indices.
        image_name: file name of the image inside ``images_path``.
        dataset_name: name of the dataset, used to build the output directory.
        trial_info: not used by this function; kept so existing callers keep working.
        images_path: directory that holds the dataset's images.
    """
    fixation_size = (img_scanpath['receptive_height'], img_scanpath['receptive_width'])
    scanpath_img_size = (img_scanpath['image_height'], img_scanpath['image_width'])

    img = io.imread(path.join(images_path, image_name))
    img_size_used = scanpath_img_size
    original_img_size = img.shape[:2]

    # nnIBS uses a grid for images, it's necessary to upscale it
    is_grid = 'IBS' in subject
    if is_grid:
        img_size_used = (768, 1024)
        fixation_size = (img_size_used[0] // scanpath_img_size[0], img_size_used[1] // scanpath_img_size[1])

    # The image is shown at the size the (rescaled) fixations are expressed in
    img = transform.resize(img, img_size_used)

    # Rescale scanpath if necessary
    X = [rescale_coordinate(x, scanpath_img_size[1], img_size_used[1], fixation_size[1], is_grid) for x in img_scanpath['X']]
    Y = [rescale_coordinate(y, scanpath_img_size[0], img_size_used[0], fixation_size[0], is_grid) for y in img_scanpath['Y']]

    # Work on a copy: rescaling the caller's list in place would compound the
    # scaling if the same scanpath entry were ever processed again
    bbox = list(img_scanpath['target_bbox'])
    if is_grid:
        bbox[0], bbox[2] = [rescale_coordinate(pos, original_img_size[0], scanpath_img_size[0], fixation_size[0], is_grid) for pos in (bbox[0], bbox[2])]
        bbox[1], bbox[3] = [rescale_coordinate(pos, original_img_size[1], scanpath_img_size[1], fixation_size[1], is_grid) for pos in (bbox[1], bbox[3])]

    # Convert [top, left, bottom, right] into the [x, y, width, height] layout used for plotting
    target_height = bbox[2] - bbox[0]
    target_width = bbox[3] - bbox[1]
    bbox = [bbox[1], bbox[0], target_width, target_height]

    # splitext copes with extensions of any length (e.g. '.jpeg'), unlike slicing off four characters
    image_stem = path.splitext(image_name)[0]
    save_path = path.join('Plots', dataset_name + '_dataset', image_stem)
    makedirs(save_path, exist_ok=True)

    title = image_stem + '_' + subject.replace(' ', '_')

    plot_scanpath(img, X, Y, fixation_size, bbox, title, save_path)
# Script entry point: select the scanpath(s) requested on the command line
# (from a model's results or from a human subject) and plot them.
if __name__ == '__main__':
    args = parse_args()

    if not args.human:
        # Model scanpaths are read from <RESULTS_DIR>/<dataset>_dataset/<model>/Scanpaths.json
        scanpaths_dir = path.join(RESULTS_DIR, args.dataset + '_dataset', args.model)
        if not path.exists(scanpaths_dir):
            print('There are no results for ' + args.model + ' in the ' + args.dataset + ' dataset')
            sys.exit(0)

        scanpaths_file = path.join(scanpaths_dir, 'Scanpaths.json')
        scanpaths = load_dict_from_json(scanpaths_file)

        if args.img != 'notfound':
            if args.img not in scanpaths:
                print('Image not found in ' + args.model + ' scanpaths')
                sys.exit(0)
            img_scanpath = scanpaths[args.img]

        subject = args.model
    else:
        # Human scanpaths are stored as one JSON file per subject
        human_scanpaths_dir = path.join(DATASETS_DIR, args.dataset, 'human_scanpaths')
        if not path.exists(human_scanpaths_dir) or not listdir(human_scanpaths_dir):
            print('There are no human subjects scanpaths for this dataset')
            sys.exit(0)

        human_scanpaths_files = sorted(listdir(human_scanpaths_dir))
        number_of_subjects = len(human_scanpaths_files)

        if isinstance(args.human, str):
            # Subject IDs given by the user are one-based
            human_subject = int(args.human) - 1
            # Reject IDs outside the available subjects; otherwise 0 would
            # silently select the last file through negative indexing
            if not 0 <= human_subject < number_of_subjects:
                print('There is no human subject ' + args.human + ' in this dataset')
                sys.exit(0)
        else:
            human_subject = randint(0, number_of_subjects - 1)

        target_found = False
        checked_subjects = []
        while not target_found:
            scanpaths_file = path.join(human_scanpaths_dir, human_scanpaths_files[human_subject])
            scanpaths = load_dict_from_json(scanpaths_file)
            if args.img in scanpaths:
                img_scanpath = scanpaths[args.img]
                target_found = img_scanpath['target_found']

            if not target_found:
                checked_subjects.append(human_subject)
                # Give up once every subject was tried, or at once if a specific subject was requested
                if len(checked_subjects) == number_of_subjects or isinstance(args.human, str):
                    print('No successful trial has been found for image ' + args.img)
                    sys.exit(0)

                # Pick another subject at random among those not tried yet
                unchecked_subjects = [index for index in range(number_of_subjects) if index not in checked_subjects]
                human_subject = unchecked_subjects[randint(0, len(unchecked_subjects) - 1)]

        # NOTE(review): characters 4-5 of the file name are presumably the subject number; confirm against the dataset's naming
        subject = 'Human subject ' + human_scanpaths_files[human_subject][4:6]

    dataset_path = path.join(DATASETS_DIR, args.dataset)
    dataset_info = load_dict_from_json(path.join(dataset_path, 'dataset_info.json'))
    images_path = path.join(dataset_path, dataset_info['images_dir'])

    trials_properties_file = path.join(dataset_path, 'trials_properties.json')
    trials_properties = load_dict_from_json(trials_properties_file)

    if args.img == 'notfound' and not args.human:
        # Plot every image in which the model did not find the target. The trial
        # is looked up per image: 'notfound' itself is not an image of the
        # dataset, so looking it up would always raise.
        for image_name, image_scanpath in scanpaths.items():
            if not image_scanpath['target_found']:
                trial_info = get_trial_info(image_name, trials_properties)
                process_image(image_scanpath, subject, image_name, args.dataset, trial_info, images_path)
    else:
        trial_info = get_trial_info(args.img, trials_properties)
        process_image(img_scanpath, subject, args.img, args.dataset, trial_info, images_path)