import torch
from pytorch_transformers import BertForNextSentencePrediction
from pytorch_transformers import BertTokenizer
from torch.nn.functional import cosine_similarity
from torch.nn.functional import softmax
from torch.nn.utils.rnn import pad_sequence

BERT_MODEL_VERSION = 'bert-base-uncased'
MAX_SENTENCE_LENGTH = 512  # BERT's maximum input length, in tokens

# Load the tokenizer and the next-sentence-prediction model once at import
# time; output_hidden_states=True also exposes the per-layer hidden states
# used for sentence embeddings below.
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_VERSION)
model = BertForNextSentencePrediction.from_pretrained(
    BERT_MODEL_VERSION,
    output_hidden_states=True,
)
model.eval()
if torch.cuda.is_available():
    model.cuda()


def calculate_similarities(
    query_embedding,
    document_embeddings,
):
    # Cosine similarity between one query embedding and a batch of document
    # embeddings, computed along the embedding dimension.
    return cosine_similarity(
        query_embedding,
        document_embeddings,
        dim=1,
    )


def calculate_next_sentence_probability(query, sentences):
    # Score each candidate sentence with BERT's next-sentence-prediction
    # head: build "[CLS] query [SEP] sentence [SEP]" pairs, batch them, and
    # return the probability that the sentence follows the query.
    all_indexed_tokens = []
    tokenized_query = tokenizer.tokenize(query)
    # Leave room for [CLS], two [SEP] tokens, and the query itself.
    max_sentence_length = (
        MAX_SENTENCE_LENGTH -
        3 -
        len(tokenized_query)
    )
    for sentence in sentences:
        tokenized_text = (
            [tokenizer.cls_token] +
            tokenized_query +
            [tokenizer.sep_token] +
            tokenizer.tokenize(sentence)[:max_sentence_length] +
            [tokenizer.sep_token]
        )
        indexed_tokens = tokenizer.convert_tokens_to_ids(
            tokenized_text,
        )
        all_indexed_tokens.append(torch.tensor(indexed_tokens))
    # Pad the batch to a common length; BERT's [PAD] id is 0, so the
    # attention mask is simply "token id != 0".
    tokens_tensor = pad_sequence(
        all_indexed_tokens,
        batch_first=True,
    )
    attention_mask = torch.where(
        tokens_tensor != 0,
        torch.ones(tokens_tensor.shape),
        torch.zeros(tokens_tensor.shape),
    )
    if torch.cuda.is_available():
        tokens_tensor = tokens_tensor.cuda()
        attention_mask = attention_mask.cuda()
    with torch.no_grad():
        outputs = model(
            tokens_tensor,
            attention_mask=attention_mask,
        )
    # outputs[0] holds the two next-sentence logits per pair; class 0 is
    # "sentence B follows sentence A".
    return softmax(outputs[0], dim=1)[:, 0]


def embed_sentences(sentences):
    # Embed each sentence as the mean of its token vectors taken from one of
    # BERT's hidden layers.
    all_indexed_tokens = []
    for text in sentences:
        # Truncate the token sequence to leave room for [CLS] and [SEP].
        tokenized_text = (
            [tokenizer.cls_token] +
            tokenizer.tokenize(text)[:MAX_SENTENCE_LENGTH - 2] +
            [tokenizer.sep_token]
        )
        indexed_tokens = tokenizer.convert_tokens_to_ids(
            tokenized_text,
        )
        all_indexed_tokens.append(torch.tensor(indexed_tokens))
    tokens_tensor = pad_sequence(
        all_indexed_tokens,
        batch_first=True,
    )
    attention_mask = torch.where(
        tokens_tensor != 0,
        torch.ones(tokens_tensor.shape),
        torch.zeros(tokens_tensor.shape),
    )
    if torch.cuda.is_available():
        tokens_tensor = tokens_tensor.cuda()
        attention_mask = attention_mask.cuda()
    with torch.no_grad():
        output = model.bert(
            tokens_tensor,
            attention_mask=attention_mask,
        )
    # output[2] is the tuple of hidden states (embedding output plus each
    # encoder layer); mean-pool the token vectors of the selected layer.
    return output[2][-11].mean(dim=1)
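

# A minimal usage sketch, not part of the original module: the query and
# candidate sentences below are made-up examples showing how the helpers
# above are intended to fit together.
if __name__ == '__main__':
    query = 'The cat sat on the mat.'
    candidates = [
        'It was a very comfortable mat.',
        'Stock markets closed higher today.',
    ]

    # Rank candidates by cosine similarity of their BERT embeddings.
    query_embedding = embed_sentences([query])
    candidate_embeddings = embed_sentences(candidates)
    similarities = calculate_similarities(query_embedding, candidate_embeddings)
    print('cosine similarities:', similarities.tolist())

    # Rank candidates by next-sentence probability given the query.
    probabilities = calculate_next_sentence_probability(query, candidates)
    print('next-sentence probabilities:', probabilities.tolist())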