# vocab.py (forked from yandexdataschool/nlp_course)

import torch
import torch.nn.functional as F

class Vocab:
    def __init__(self, tokens, bos="_BOS_", eos="_EOS_", unk="_UNK_"):
        """
        Converts lines of tokens into index matrices and back.
        """
        assert all(tok in tokens for tok in (bos, eos, unk))
        self.tokens = tokens
        self.token_to_ix = {t: i for i, t in enumerate(tokens)}
        self.bos, self.eos, self.unk = bos, eos, unk
        self.bos_ix = self.token_to_ix[bos]
        self.eos_ix = self.token_to_ix[eos]
        self.unk_ix = self.token_to_ix[unk]

    def __len__(self):
        return len(self.tokens)

    @staticmethod
    def from_lines(lines, bos="_BOS_", eos="_EOS_", unk="_UNK_"):
        """
        Builds a vocabulary from an iterable of whitespace-tokenized lines;
        the special tokens are placed at indices 0, 1 and 2.
        """
        flat_lines = '\n'.join(list(lines)).split()
        tokens = sorted(set(flat_lines))
        tokens = [t for t in tokens if t not in (bos, eos, unk) and len(t)]
        tokens = [bos, eos, unk] + tokens
        return Vocab(tokens, bos, eos, unk)

    def tokenize(self, string):
        """converts a string into a list of tokens wrapped in BOS/EOS;
        out-of-vocabulary tokens are replaced with UNK"""
        tokens = [tok if tok in self.token_to_ix else self.unk
                  for tok in string.split()]
        return [self.bos] + tokens + [self.eos]
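
    # Illustrative example (hypothetical input): with a vocab built from the
    # single line "hello world", vocab.tokenize("hello there") returns
    # ['_BOS_', 'hello', '_UNK_', '_EOS_'], since "there" is out of vocabulary.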

    def to_matrix(self, lines, dtype=torch.int64, max_len=None):
        """
        Converts variable-length token sequences into a fixed-size matrix.
        Each line is tokenized (wrapped in BOS/EOS) and right-padded with the
        EOS index. Example (with illustrative token indices, bos=0, eos=1):
        >>> print(vocab.to_matrix(lines[:2]))
        tensor([[ 0,  5,  8,  3,  1,  1],
                [ 0,  7,  4,  9,  6,  1]])
        """
        lines = list(map(self.tokenize, lines))
        max_len = max_len or max(map(len, lines))
        matrix = torch.full((len(lines), max_len), self.eos_ix, dtype=dtype)
        for i, seq in enumerate(lines):
            row_ix = list(map(self.token_to_ix.get, seq))[:max_len]
            matrix[i, :len(row_ix)] = torch.as_tensor(row_ix)
        return matrix
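
    # Shape sketch (hypothetical sizes): for three lines whose wrapped lengths
    # are 5, 7 and 9 tokens, to_matrix returns a [3, 9] int64 tensor in which
    # the two shorter rows are right-padded with self.eos_ix.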

    def to_lines(self, matrix, crop=True):
        """
        Converts a matrix of token ids back into strings.
        :param matrix: integer matrix of token ids, shape=[batch, time]
        :param crop: if True, crops BOS and everything after the first EOS
        :return: a list of strings, one per matrix row
        """
        lines = []
        for line_ix in torch.as_tensor(matrix).tolist():
            if crop:
                if line_ix[0] == self.bos_ix:
                    line_ix = line_ix[1:]
                if self.eos_ix in line_ix:
                    line_ix = line_ix[:line_ix.index(self.eos_ix)]
            line = ' '.join(self.tokens[i] for i in line_ix)
            lines.append(line)
        return lines

    def compute_mask(self, input_ix):
        """ computes a boolean mask that is True up to and including the first EOS """
        # cumsum counts how many EOS tokens occur at or before each position;
        # a position is kept while that count, shifted right by one step, is
        # still zero, which keeps everything up to and including the first EOS.
        return F.pad(torch.cumsum(input_ix == self.eos_ix, dim=-1)[..., :-1] < 1, pad=(1, 0), value=True)
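

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file); the sample lines
    # below are invented for illustration.
    lines = ["the cat sat on the mat", "the dog barked"]
    vocab = Vocab.from_lines(lines)

    # Token sequences -> padded index matrix (rows right-padded with the EOS index).
    matrix = vocab.to_matrix(lines)
    print(matrix.shape)   # torch.Size([2, 8]): 6 tokens + BOS/EOS in the longest line

    # Index matrix -> strings; BOS and everything after the first EOS are cropped.
    print(vocab.to_lines(matrix))   # ['the cat sat on the mat', 'the dog barked']

    # Boolean mask that is True up to and including the first EOS in each row.
    print(vocab.compute_mask(matrix))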