-
Notifications
You must be signed in to change notification settings - Fork 14
/
utils.py
145 lines (126 loc) · 4.71 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
import pickle
import random
import numpy as np
import torch
from PIL import Image
def cmdline_args_parser():
    """
    Build the commandline argument parser.

    Returns the configured `argparse.ArgumentParser` itself (not parsed
    arguments); callers are expected to invoke `parse_args` on it.
    `--cv_fold` is the only required option.
    """
    parser = argparse.ArgumentParser("Cmdline Arguments")

    # (flags, add_argument keyword arguments), registered in this order
    option_table = [
        (("-d", "--device"), {"type": int, "default": 0}),
        (("-e", "--n_epochs"), {"type": int, "default": 50}),
        (("-lr", "--learning_rate"), {"type": float, "default": 0.0005}),
        (("-bs", "--batch_size"), {"type": int, "default": 5}),
        (("-cs", "--context_size"), {"type": int, "default": 12}),
        (("-hd", "--hidden_dim"), {"type": int, "default": 384}),
        (("-r", "--roi"), {"type": int, "default": 3}),
        (("-bbhd", "--bbox_hidden_dim"), {"type": int, "default": 32}),
        (
            ("--use_additional_feat",),
            {"dest": "additional_feat", "action": "store_true"},
        ),
        (("-wd", "--weight_decay"), {"type": float, "default": 1e-3}),
        (("-dp", "--drop_prob"), {"type": float, "default": 0.2}),
        (("-sf", "--sampling_fraction"), {"type": float, "default": 0.9}),
        (("-nw", "--num_workers"), {"type": int, "default": 5}),
        # cvf=-1 means fold_dir is set to split_dir
        (
            ("-cvf", "--cv_fold"),
            {"type": int, "required": True, "choices": [-1, 1, 2, 3, 4, 5]},
        ),
    ]
    for flags, spec in option_table:
        parser.add_argument(*flags, **spec)

    return parser
def count_parameters(model):
    """
    Count the trainable parameters of `model`.

    Adds up `numel()` over every entry of `model.parameters()` whose
    `requires_grad` flag is truthy and returns the total.
    """
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    return n_trainable
def pkl_load(file_path):
    """
    Load and return the object stored in the pickle file at `file_path`.

    The file is opened in binary mode and closed again before returning,
    including when unpickling raises.

    NOTE: unpickling can execute arbitrary code, so only use this on files
    from a trusted source.
    """
    with open(file_path, "rb") as pkl_file:
        return pickle.load(pkl_file)
def print_and_log(msg, log_file, write_mode="a"):
    """
    Echo `msg` (string) to stdout, then write it plus a trailing newline to `log_file`.

    `write_mode` is passed to `open` as the file mode: 'a' (the default)
    appends to the file, 'w' overwrites it.
    """
    print(msg)
    with open(log_file, write_mode) as log_handle:
        line = msg + "\n"
        log_handle.write(line)
def set_all_seeds(seed=123):
    """
    Seed Python's `random`, NumPy, and torch (CPU and CUDA) with `seed`.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # The cuDNN flags below are left disabled, so they are not modified here.
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
def visualize_bbox(img_path, attn_wt_file, img_save_dir):
    """
    Plot img and overlay all context bboxes, shaded by attention score.

    The target bbox is drawn as a bold red outline. Each context bbox is
    drawn in green, with a fill whose opacity is 0.75 times its attention
    score after scaling the scores by the largest score in the file.

    Args:
        img_path: path of the image to draw on.
        attn_wt_file: csv file with one row per target class (normally 3 rows:
            Price, Title, Image) and 5 + 10*context_size cols per row:
            4 target bbox values, 1 class label, 2*context_size*4 context bbox
            values, 2*context_size attention values (documented as summing
            to 1). Each group of 4 bbox values is handed to
            `plt.Rectangle` as (x, y), width, height.
        img_save_dir: directory (must exist) that receives one
            `<image name>_attn_<class name>.png` file per csv row.
    """
    import os

    import matplotlib.pyplot as plt

    class_names = {0: "BG", 1: "Price", 2: "Title", 3: "Image"}
    target_color = "#fa4772"
    context_color = "#43a047"

    img = Image.open(img_path).convert("RGB")
    # ndmin=2 keeps a single-row file 2-D so the column arithmetic below works
    plt_data = np.loadtxt(attn_wt_file, delimiter=",", ndmin=2)
    context_size = int((plt_data.shape[1] - 5) / 10)
    n_context = 2 * context_size
    # Attention values follow the 5 leading cols and the 4 coords per context box
    attn_start = 5 + 4 * n_context

    # Scale attention by the global maximum. The slice uses absolute indices
    # (matching the per-box lookup below) so that context_size == 0 cannot
    # select the whole row, and an all-zero block is left alone instead of
    # being turned into NaN.
    attn = plt_data[:, attn_start : attn_start + n_context]
    if attn.size and attn.max() != 0:
        attn /= attn.max()

    plt.rcParams.update({"font.size": 6})
    # Strip directory and extension without assuming a 3-letter suffix
    img_stem = os.path.splitext(os.path.basename(img_path))[0]

    for row in plt_data:
        class_name = class_names[int(row[4])]
        plt.imshow(img)
        plt.title("Attention Visualization for class: " + class_name)
        ax = plt.gca()

        # Target bbox: outline only
        ax.add_patch(
            plt.Rectangle(
                (row[0], row[1]),
                row[2],
                row[3],
                fill=False,
                edgecolor=target_color,
                linewidth=1.5,
            )
        )

        for c in range(n_context):
            x, y, w, h = row[5 + 4 * c : 9 + 4 * c]
            # All-zero boxes are skipped (presumably padding for missing context)
            if x == 0 and y == 0 and w == 0 and h == 0:
                continue
            # Filled patch whose opacity encodes the attention score
            ax.add_patch(
                plt.Rectangle(
                    (x, y),
                    w,
                    h,
                    fill=True,
                    facecolor=context_color,
                    alpha=0.75 * row[attn_start + c],
                )
            )
            # Thin outline so low-attention boxes remain visible
            ax.add_patch(
                plt.Rectangle(
                    (x, y),
                    w,
                    h,
                    fill=False,
                    edgecolor=context_color,
                    linewidth=0.75,
                )
            )

        plt.axis("off")
        plt.tight_layout()
        plt.savefig(
            "%s/%s_attn_%s.png" % (img_save_dir, img_stem, class_name),
            dpi=300,
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.close()