"""
Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023
Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora
GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition
Project Website: https://abdur75648.github.io/UTRNet/
Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import os
import random
import shutil
import warnings
from datetime import datetime

import pytz
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable

warnings.filterwarnings("ignore", category=UserWarning)
class CTCLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
            self.dict[char] = i + 1
        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text: text index for CTCLoss. [batch_size, batch_max_length]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        # The index used for padding (=0) would not affect the CTC loss calculation.
        batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0)
        for i, t in enumerate(text):
            # Map each character to its index (local name avoids shadowing the `text` argument).
            indices = [self.dict[char] for char in t]
            batch_text[i][:len(indices)] = torch.LongTensor(indices)
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            t = text_index[index, :]
            char_list = []
            for i in range(l):
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)
            texts.append(text)
        return texts
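
# Illustrative usage sketch (not part of the original file; assumes a toy
# two-character charset):
#   converter = CTCLabelConverter('ab')
#   batch, lengths = converter.encode(['ab', 'b'], batch_max_length=5)
#   # batch   -> [[1, 2, 0, 0, 0], [2, 0, 0, 0, 0]]  (0 is the CTC blank / padding)
#   # lengths -> [2, 1]
#   converter.decode(batch, lengths)  # -> ['ab', 'b']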
class CTCLabelConverterForBaiduWarpctc(object):
    """ Convert between text-label and text-index for baidu warpctc """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss
            self.dict[char] = i + 1
        self.character = ['[CTCblank]'] + dict_character  # dummy '[CTCblank]' token for CTCLoss (index 0)

    def encode(self, text, batch_max_length=25):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        # Concatenate all labels into one flat index list (the layout warp-ctc expects).
        flat = ''.join(text)
        indices = [self.dict[char] for char in flat]
        return (torch.IntTensor(indices), torch.IntTensor(length))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]
            char_list = []
            for i in range(l):
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)
            texts.append(text)
            index += l
        return texts
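
# Illustrative sketch of the warp-ctc layout (not part of the original file;
# same toy charset as above): labels are concatenated into one flat tensor
# instead of being padded into a batch matrix.
#   converter = CTCLabelConverterForBaiduWarpctc('ab')
#   flat, lengths = converter.encode(['ab', 'b'])
#   # flat -> [1, 2, 2], lengths -> [2, 1]
#   converter.decode(flat, lengths)  # -> ['ab', 'b']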
class AttnLabelConverter(object):
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
        list_token = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
        list_character = list(character)
        self.character = list_token + list_character
        self.dict = {}
        for i, char in enumerate(self.character):
            self.dict[char] = i

    def encode(self, text, batch_max_length=25):
        """ convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default
        output:
            text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
                text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
            length : the length of the attention decoder output, which also counts the [s] token. [3, 7, ....] [batch_size]
        """
        length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence.
        # batch_max_length = max(length) # this is not allowed for multi-gpu setting
        batch_max_length += 1
        # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
        batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
        for i, t in enumerate(text):
            chars = list(t)
            chars.append('[s]')
            try:
                indices = [self.dict[char] for char in chars]
            except KeyError:
                # A label containing an out-of-charset character is skipped;
                # its row stays filled with [GO] tokens.
                continue
            batch_text[i][1:1 + len(indices)] = torch.LongTensor(indices)  # batch_text[:, 0] = [GO] token
        return (batch_text, torch.IntTensor(length))

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        for index, l in enumerate(length):
            text = ''.join([self.character[i] for i in text_index[index, :]])
            texts.append(text)
        return texts
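
# Illustrative usage sketch (not part of the original file; toy charset 'ab',
# so the dict is {'[GO]': 0, '[s]': 1, 'a': 2, 'b': 3}):
#   converter = AttnLabelConverter('ab')
#   batch, lengths = converter.encode(['ab'], batch_max_length=4)
#   # batch   -> [[0, 2, 3, 1, 0, 0]]  ([GO], 'a', 'b', [s], then [GO] padding)
#   # lengths -> [3]  (label length + 1 for the trailing [s])
# decode() returns the raw token string, so callers typically truncate the
# result at the first '[s]'.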
def imshow(img, title, batch_size=1):
    """Un-normalizes an image tensor (ImageNet mean/std) and displays it."""
    std_correction = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)
    mean_correction = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
    npimg = np.multiply(img.numpy(), std_correction) + mean_correction
    plt.figure(figsize=(batch_size * 4, 4))
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
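
# Illustrative call (not part of the original file; assumes torchvision is
# available and image_batch is a (B, 3, H, W) ImageNet-normalized tensor):
#   grid = torchvision.utils.make_grid(image_batch)
#   imshow(grid, "sample batch", batch_size=image_batch.size(0))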
class Averager(object):
    """Compute average for torch.Tensor, used for loss average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
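
# Illustrative usage sketch (not part of the original file):
#   loss_avg = Averager()
#   loss_avg.add(torch.tensor([0.5]))
#   loss_avg.add(torch.tensor([1.5]))
#   float(loss_avg.val())  # -> 1.0 (sum of elements / element count)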
class Logger(object):
    """For logging while training"""

    def __init__(self, path):
        self.logFile = path
        datetime_now = str(datetime.now(pytz.timezone('Asia/Kolkata')).strftime("%Y-%m-%d_%H-%M-%S"))
        with open(self.logFile, "w", encoding="utf-8") as f:
            f.write("Logging started @ " + str(datetime_now) + "\n")

    def log(self, *args):
        # Join all arguments into one space-separated message, print it, and append it to the log file.
        message = " ".join(str(x) for x in args).strip()
        print(message)
        with open(self.logFile, "a", encoding="utf-8") as f:
            f.write(message + "\n")
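
# Illustrative usage sketch (not part of the original file; the path is hypothetical):
#   logger = Logger("train.log")             # writes a timestamped header line
#   logger.log("epoch", 1, "loss:", 0.42)    # prints and appends "epoch 1 loss: 0.42"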
def allign_two_strings(x: str, y: str, pxy: int = 1, pgap: int = 1):
    """
    Aligns two strings by inserting '_' gap characters (global sequence
    alignment with mismatch penalty pxy and gap penalty pgap).
    Source: https://www.geeksforgeeks.org/sequence-alignment-problem/
    """
    m = len(x)
    n = len(y)
    # Fill the dynamic-programming table of alignment costs.
    dp = np.zeros([m + 1, n + 1], dtype=int)
    dp[0:(m + 1), 0] = [i * pgap for i in range(m + 1)]
    dp[0, 0:(n + 1)] = [i * pgap for i in range(n + 1)]
    i = 1
    while i <= m:
        j = 1
        while j <= n:
            if x[i - 1] == y[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = min(dp[i - 1][j - 1] + pxy,
                               dp[i - 1][j] + pgap,
                               dp[i][j - 1] + pgap)
            j += 1
        i += 1
    # Trace back through the table to reconstruct the aligned strings.
    l = n + m
    i = m
    j = n
    xpos = l
    ypos = l
    xans = np.zeros(l + 1, dtype=int)
    yans = np.zeros(l + 1, dtype=int)
    while not (i == 0 or j == 0):
        if x[i - 1] == y[j - 1]:
            xans[xpos] = ord(x[i - 1])
            yans[ypos] = ord(y[j - 1])
            xpos -= 1
            ypos -= 1
            i -= 1
            j -= 1
        elif (dp[i - 1][j - 1] + pxy) == dp[i][j]:
            xans[xpos] = ord(x[i - 1])
            yans[ypos] = ord(y[j - 1])
            xpos -= 1
            ypos -= 1
            i -= 1
            j -= 1
        elif (dp[i - 1][j] + pgap) == dp[i][j]:
            xans[xpos] = ord(x[i - 1])
            yans[ypos] = ord('_')
            xpos -= 1
            ypos -= 1
            i -= 1
        elif (dp[i][j - 1] + pgap) == dp[i][j]:
            xans[xpos] = ord('_')
            yans[ypos] = ord(y[j - 1])
            xpos -= 1
            ypos -= 1
            j -= 1
    # Pad whatever remains of either string with leading gaps.
    while xpos > 0:
        if i > 0:
            i -= 1
            xans[xpos] = ord(x[i])
            xpos -= 1
        else:
            xans[xpos] = ord('_')
            xpos -= 1
    while ypos > 0:
        if j > 0:
            j -= 1
            yans[ypos] = ord(y[j])
            ypos -= 1
        else:
            yans[ypos] = ord('_')
            ypos -= 1
    # Find where the meaningful (non-gap) part of the alignment begins.
    id = 1
    i = l
    while i >= 1:
        if chr(xans[i]) == '_' and chr(yans[i]) == '_':
            id = i + 1
            break
        i -= 1
    i = id
    x_seq = ""
    while i <= l:
        x_seq += chr(xans[i])
        i += 1
    i = id
    y_seq = ""
    while i <= l:
        y_seq += chr(yans[i])
        i += 1
    return x_seq, y_seq
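
# Worked example (not part of the original file): aligning a prediction against
# its ground truth pads the shorter string with '_' gaps so the two can be
# compared character by character.
#   allign_two_strings("abc", "ac")  # -> ("abc", "a_c")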
# Count the number of trainable parameters in a model, reported in millions.
def count_parameters(model, precision=2):
    return round(sum(p.numel() for p in model.parameters() if p.requires_grad) / 10.**6, precision)
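
# Illustrative sketch (not part of the original file; hypothetical model):
#   count_parameters(torch.nn.Linear(512, 256))  # -> 0.13 (i.e. ~131k trainable params)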
'''
# Code for counting the number of FLOPs in the CNN backbone during inference
Source - https://github.com/fdbtrs/ElasticFace/blob/main/utils/countFLOPS.py
'''
def count_model_flops(model, in_channels=1, input_res=[32, 400], multiply_adds=True):
    list_conv = []

    def conv_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
        bias_ops = 1 if self.bias is not None else 0
        flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
        list_conv.append(flops)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1
        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        if self.bias is not None:
            bias_ops = self.bias.nelement()
            flops = batch_size * (weight_ops + bias_ops)
        else:
            flops = batch_size * weight_ops
        list_linear.append(flops)

    list_bn = []

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement() * 2)

    list_relu = []

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []

    def pooling_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        # If kernel_size is a tuple, compute ops as the product of its elements; if it is an int, as its square.
        kernel_ops = self.kernel_size[0] * self.kernel_size[1] if isinstance(self.kernel_size, tuple) else self.kernel_size * self.kernel_size
        flops = kernel_ops * output_channels * output_height * output_width * batch_size
        list_pooling.append(flops)

    def dropout_hook(self, input, output):
        # Approximate a dropout op as one comparison and one multiplication per element.
        batch_size, input_channels, input_height, input_width = input[0].size()
        list_conv.append(2 * batch_size * input_channels * input_height * input_width)

    def sigmoid_hook(self, input, output):
        # Approximate a sigmoid op as two multiplications and one addition per element.
        batch_size, input_channels, input_height, input_width = input[0].size()
        list_conv.append(3 * batch_size * input_channels * input_height * input_width)

    def upsample_hook(self, input, output):
        batch_size, input_channels, input_height, input_width = input[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_ops = self.scale_factor * self.scale_factor  # * (self.in_channels / self.groups)
        flops = (kernel_ops * (2 if multiply_adds else 1)) * output_channels * output_height * output_width * batch_size
        list_conv.append(flops)

    handles = []

    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                handles.append(net.register_forward_hook(conv_hook))
            elif isinstance(net, torch.nn.Linear):
                handles.append(net.register_forward_hook(linear_hook))
            elif isinstance(net, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
                handles.append(net.register_forward_hook(bn_hook))
            elif isinstance(net, (torch.nn.ReLU, torch.nn.PReLU)):
                handles.append(net.register_forward_hook(relu_hook))
            elif isinstance(net, (torch.nn.MaxPool2d, torch.nn.AvgPool2d)):
                handles.append(net.register_forward_hook(pooling_hook))
            elif isinstance(net, torch.nn.Dropout):
                handles.append(net.register_forward_hook(dropout_hook))
            elif isinstance(net, torch.nn.Upsample):
                handles.append(net.register_forward_hook(upsample_hook))
            elif isinstance(net, torch.nn.Sigmoid):
                handles.append(net.register_forward_hook(sigmoid_hook))
            else:
                print("warning: no FLOPs hook registered for " + str(net))
            return
        for c in childrens:
            foo(c)

    model.eval()
    foo(model)
    # NOTE: the dummy input is built as (1, in_channels, input_res[1], input_res[0]),
    # i.e. (1, 1, 400, 32) with the defaults.
    dummy_input = Variable(torch.rand(in_channels, input_res[1], input_res[0]).unsqueeze(0), requires_grad=True)
    out = model(dummy_input)
    total_flops = sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling)
    for h in handles:
        h.remove()
    model.train()

    def flops_to_string(flops, units='MFLOPS', precision=4):
        if units == 'GFLOPS':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MFLOPS':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KFLOPS':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' FLOPS'

    return flops_to_string(total_flops)
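
# Illustrative call (not part of the original file; hypothetical toy CNN):
#   cnn = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 3, padding=1), torch.nn.ReLU())
#   count_model_flops(cnn, in_channels=1, input_res=[32, 400])  # -> '2.048 MFLOPS'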
def draw_feature_map(visual_feature, vis_dir, num_channel=10):
    """Draws feature maps for the given visual features.
    Args:
        visual_feature (Tensor): Shape (B, C, H, W); only the first item in the batch is used.
        vis_dir (str): Directory to save the feature maps.
        num_channel (int): Number of randomly sampled channels to visualize.
    """
    if os.path.exists(vis_dir):
        shutil.rmtree(vis_dir)
    os.makedirs(vis_dir)
    # Save visual_feature from num_channel random channels for visualization
    for i in range(num_channel):
        random_channel = random.randint(0, visual_feature.shape[1] - 1)
        visual_feature_for_visualization = visual_feature[0, random_channel, :, :].detach().cpu().numpy()
        # Horizontal flip
        visual_feature_for_visualization = visual_feature_for_visualization[:, ::-1]
        # Min-max normalize to [0, 1]
        visual_feature_for_visualization = (visual_feature_for_visualization - visual_feature_for_visualization.min()) / (visual_feature_for_visualization.max() - visual_feature_for_visualization.min())
        # Draw heatmap
        plt.imshow(visual_feature_for_visualization, cmap='gray', interpolation='nearest')
        plt.axis("off")
        plt.savefig(os.path.join(vis_dir, "channel_{}.png".format(random_channel)), bbox_inches='tight', pad_inches=0)
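
# Illustrative call (not part of the original file; `extract_features` is a
# hypothetical method standing in for whatever produces the (B, C, H, W) tensor):
#   features = model.extract_features(image)
#   draw_feature_map(features, "feature_maps", num_channel=10)
# Note: the output directory is deleted and recreated on each call, and channels
# are sampled with replacement, so fewer than num_channel distinct maps may be saved.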