-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
71 lines (56 loc) · 2.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
Assignment 5: CRNN For Text Recognition
Course Coordinator: Dr. Manojkumar Ramteke
Teaching Assistant: Abdur Rahman
This code is for educational purposes only. Unauthorized copying or distribution without the consent of the course coordinator is prohibited.
Copyright © 2024. All rights reserved.
"""
import torch
class ConverterForCTC:
    """Convert between text labels and CTC index sequences.

    Index 0 is reserved for the CTC blank token, so the i-th character of
    ``character`` is encoded as ``i + 1``.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character (str): set of the possible characters, in index order.
                If a character occurs more than once, its last occurrence
                determines the index that ``encode`` uses.
        """
        dict_character = list(character)
        # NOTE: 0 is reserved for the 'CTCblank' token required by CTCLoss,
        # so real characters start at index 1.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}
        # Dummy '[CTCblank]' entry at index 0 so that self.character[idx]
        # maps an index straight back to its character.
        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a zero-padded index tensor for CTCLoss.

        Args:
            text (list[str]): text label of each image. [batch_size]
            batch_max_length (int): width of the padded output. 25 by default.

        Returns:
            tuple: ``(batch_text, length)`` where ``batch_text`` is a
            ``torch.long`` tensor of shape [batch_size, batch_max_length],
            right-padded with 0, and ``length`` is a ``torch.int32`` tensor
            with the length of each label. [batch_size]

        Raises:
            KeyError: if a label contains a character not in ``character``.
            RuntimeError: (from torch) if a label is longer than
                ``batch_max_length``.
        """
        length = [len(label) for label in text]
        # The padding index (0) does not affect the CTC loss as long as
        # ``length`` is passed to CTCLoss as the target lengths.
        batch_text = torch.zeros(len(text), batch_max_length, dtype=torch.long)
        for i, label in enumerate(text):
            indices = [self.dict[char] for char in label]
            batch_text[i, :len(indices)] = torch.tensor(indices, dtype=torch.long)
        return batch_text, torch.tensor(length, dtype=torch.int32)

    def decode(self, text_index, length):
        """Convert index sequences back into text (greedy CTC decoding).

        For each row, only the first ``length[i]`` steps are read.
        Consecutive repeated indices are collapsed and blanks (index 0) are
        dropped, so ``[a, a, 0, a]`` decodes to two ``a`` characters.
        Because of this collapsing, a label with doubled characters (e.g.
        ``'11'``) does not round-trip through ``encode``/``decode``.

        Args:
            text_index: 2-D indexable of class indices, one row per sample
                (e.g. the tensor returned by ``encode`` or argmax predictions).
            length: number of steps to read from each row. [batch_size]

        Returns:
            list[str]: the decoded string for each row.
        """
        texts = []
        for index, l in enumerate(length):
            # Convert the row to plain Python ints once: comparing 0-d tensors
            # element by element is slow, and on CUDA every truth test forces
            # a device-to-host copy.
            row = text_index[index, :].tolist()
            char_list = []
            for i in range(l):
                idx = row[i]
                # Keep non-blank indices that differ from the previous step
                # (removes repeated characters and blanks).
                if idx != 0 and (i == 0 or row[i - 1] != idx):
                    char_list.append(self.character[idx])
            texts.append(''.join(char_list))
        return texts
########## Example usage ##########
# character = '0123456789' # [CTCblank] is automatically added at the beginning
# converter = ConverterForCTC(character)
# text = ['123', '456'] # Batch Size = 2
# encoded_text, length = converter.encode(text, batch_max_length=5)
# print(encoded_text) # tensor([[2, 3, 4, 0, 0], [5, 6, 7, 0, 0]])
# print(length) # tensor([3, 3], dtype=torch.int32)
# decoded_text = converter.decode(encoded_text, length)
# print(decoded_text) # ['123', '456']
# Note: decode() applies CTC collapsing (consecutive repeats are merged), so a
# label with doubled characters such as '112' does not round-trip through
# encode()/decode().