# -*- encoding: utf-8 -*-
'''
@File : crf.py
@Time : 2019/11/23 17:35:36
@Author : Cao Shuai
@Version : 1.0
@Contact : [email protected]
@License : (C)Copyright 2018-2019, MILAB_SCU
@Desc : None
'''
import torch
import torch.nn as nn
from transformers import BertModel
from bert_or_thesues import BertModelThesues
def argmax(vec):
# return the argmax as a python int
_, idx = torch.max(vec, 1)
return idx.item()
# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
max_score = vec[0, argmax(vec)]
max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
return max_score + \
torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
def log_sum_exp_batch(log_Tensor, axis=-1): # shape (batch_size,n,m)
return torch.max(log_Tensor, axis)[0] + \
torch.log(torch.exp(log_Tensor-torch.max(log_Tensor, axis)[0].view(log_Tensor.shape[0],-1,1)).sum(axis))
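# A quick shape sketch (the numbers are arbitrary): for an input of shape
# (batch_size, n, m) with axis=-1 this should agree with torch.logsumexp
# along the last dimension, e.g.
#     scores = torch.randn(4, 16, 16)
#     assert torch.allclose(log_sum_exp_batch(scores, axis=-1),
#                           torch.logsumexp(scores, dim=-1), atol=1e-6)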
class Bert_BiLSTM_CRF(nn.Module):
def __init__(self, tag_to_ix, model_dir='bert-base-chinese',
hidden_dim=768, bert_thesues=False, fine_tune_scc=False, scc_layer=6, device='cpu'):
super(Bert_BiLSTM_CRF, self).__init__()
self.tag_to_ix = tag_to_ix
self.tagset_size = len(tag_to_ix)
# self.hidden = self.init_hidden()
self.lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=hidden_dim, hidden_size=hidden_dim//2, batch_first=True)
self.transitions = nn.Parameter(torch.randn(
self.tagset_size, self.tagset_size
))
self.hidden_dim = hidden_dim
self.start_label_id = self.tag_to_ix['[CLS]']
self.end_label_id = self.tag_to_ix['[SEP]']
self.fc = nn.Linear(hidden_dim, self.tagset_size)
# self.bert = BertModel.from_pretrained('/root/workspace/qa_project/chinese_L-12_H-768_A-12')
if not bert_thesues:
self.bert = BertModel.from_pretrained(model_dir)
else:
self.bert = BertModelThesues.from_pretrained(model_dir, fine_tune_scc=fine_tune_scc, scc_layer=scc_layer)
        # self.bert.eval()  # only used to extract BERT embeddings
        # transitions[i, j] is the score of moving from tag j to tag i,
        # so forbid transitions into [CLS] and out of [SEP].
        self.transitions.data[self.start_label_id, :] = -10000
        self.transitions.data[:, self.end_label_id] = -10000
self.device = torch.device(device)
# self.transitions.to(self.device)
def init_hidden(self):
return (torch.randn(2, 1, self.hidden_dim // 2),
torch.randn(2, 1, self.hidden_dim // 2))
def _forward_alg(self, feats):
        '''
        Forward (alpha-)recursion of the CRF: computes log p(x_1:T), the
        log partition function summed over all possible tag sequences.
        '''
# T = self.max_seq_length
T = feats.shape[1]
batch_size = feats.shape[0]
        # alpha recursion: alpha(z_t) = p(z_t, x_1:t)
        log_alpha = torch.Tensor(batch_size, 1, self.tagset_size).fill_(-10000.).to(self.device)  # [batch_size, 1, tagset_size]
        # initialisation: all probability mass sits on the start label ([CLS]);
        # in log space, 0 corresponds to p = 1.
        log_alpha[:, 0, self.start_label_id] = 0
        # feats: sentences -> BERT embeddings -> LSTM -> linear layer -> emission scores;
        # feats[:, t] has shape (batch_size, tagset_size)
        for t in range(1, T):
            # log_alpha_t[k] = logsumexp_j( transitions[k, j] + log_alpha_{t-1}[j] ) + feats[t, k]
            log_alpha = (log_sum_exp_batch(self.transitions + log_alpha, axis=-1) + feats[:, t]).unsqueeze(1)
# log_prob of all barX
log_prob_all_barX = log_sum_exp_batch(log_alpha)
return log_prob_all_barX
def _score_sentence(self, feats, label_ids):
T = feats.shape[1]
batch_size = feats.shape[0]
batch_transitions = self.transitions.expand(batch_size,self.tagset_size,self.tagset_size)
batch_transitions = batch_transitions.flatten(1)
score = torch.zeros((feats.shape[0],1)).to(self.device)
        # the 0th position is the start label ([CLS]) with probability 1, so t starts at 1.
        for t in range(1, T):
            # transitions is flattened, so index label_t * tagset_size + label_{t-1}
            # picks transitions[label_t, label_{t-1}]; the second term adds the
            # emission score of the gold tag at position t.
            score = score + \
                batch_transitions.gather(-1, (label_ids[:, t]*self.tagset_size+label_ids[:, t-1]).view(-1,1)) \
                + feats[:, t].gather(-1, label_ids[:, t].view(-1,1)).view(-1,1)
return score
def _bert_enc(self, x):
"""
x: [batchsize, sent_len]
enc: [batch_size, sent_len, 768]
"""
        with torch.no_grad():  # BERT is used as a frozen feature extractor; no gradients flow into it
            encoded_layer, _ = self.bert(x)  # (sequence_output, pooled_output) when the model returns a tuple
# enc = encoded_layer[-1]
return encoded_layer
def _viterbi_decode(self, feats):
        '''
        Max-product (Viterbi) algorithm: argmax_z p(z_0:T | x_0:T)
        '''
# T = self.max_seq_length
T = feats.shape[1]
batch_size = feats.shape[0]
# batch_transitions=self.transitions.expand(batch_size,self.tagset_size,self.tagset_size)
log_delta = torch.Tensor(batch_size, 1, self.tagset_size).fill_(-10000.).to(self.device)
log_delta[:, 0, self.start_label_id] = 0.
        # psi[t][k] is the backpointer: the index of the previous state that
        # maximises the score of reaching state k at step t.
        psi = torch.zeros((batch_size, T, self.tagset_size), dtype=torch.long).to(self.device)  # psi[0] is unused
for t in range(1, T):
            # delta[t][k] = max_{z_1:t-1}( p(x_1,...,x_t, z_1,...,z_{t-1}, z_t=k | theta) ):
            # the best score of any path that ends in z_t = k
            log_delta, psi[:, t] = torch.max(self.transitions + log_delta, -1)
            # psi[t][k] = argmax_{z_{t-1}}( p(x_1,...,x_t, z_1,...,z_{t-1}, z_t=k | theta) ),
            # i.e. the index of z_{t-1} on the best path ending in z_t = k
            log_delta = (log_delta + feats[:, t]).unsqueeze(1)
# trace back
        path = torch.zeros((batch_size, T), dtype=torch.long).to(self.device)
# max p(z1:t,all_x|theta)
        max_logLL_allz_allx, path[:, -1] = torch.max(log_delta.squeeze(1), -1)  # squeeze(1) keeps the batch dim when batch_size == 1
for t in range(T-2, -1, -1):
            # choose the state of z_t according to the state chosen for z_{t+1}
path[:, t] = psi[:, t+1].gather(-1,path[:, t+1].view(-1,1)).squeeze()
return max_logLL_allz_allx, path
def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)  # [batch_size, max_len, tagset_size]
forward_score = self._forward_alg(feats)
gold_score = self._score_sentence(feats, tags)
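        # forward_score = log Z(x) (sum over all tag sequences) and gold_score is the score of
        # the gold sequence, so the mean of their difference is the average -log p(tags | sentence).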
return torch.mean(forward_score - gold_score)
def _get_lstm_features(self, sentence):
"""sentence is the ids"""
# self.hidden = self.init_hidden()
embeds = self._bert_enc(sentence) # [8, 75, 768]
        # run the BERT embeddings through the BiLSTM
enc, _ = self.lstm(embeds)
lstm_feats = self.fc(enc)
return lstm_feats # [8, 75, 16]
    def forward(self, sentence):  # don't confuse this with _forward_alg above
# Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)  # [batch_size, max_len, tagset_size]
# Find the best path, given the features.
score, tag_seq = self._viterbi_decode(lstm_feats)
return score, tag_seq
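

# ---------------------------------------------------------------------------
# Minimal usage sketch. The tag set, vocabulary size, batch size and sequence
# length below are illustrative assumptions; the only hard requirement is that
# tag_to_ix contains '[CLS]' and '[SEP]', which the CRF uses as its start and
# end labels. Running this downloads bert-base-chinese on first use.
if __name__ == '__main__':
    tag_to_ix = {'[CLS]': 0, '[SEP]': 1, 'O': 2, 'B-LOC': 3, 'I-LOC': 4}
    model = Bert_BiLSTM_CRF(tag_to_ix, model_dir='bert-base-chinese', device='cpu')

    # Fake batch: token ids in [0, 21128) (the bert-base-chinese vocab size)
    # and gold tag ids in [0, len(tag_to_ix)).
    sentences = torch.randint(0, 21128, (2, 10))
    tags = torch.randint(0, len(tag_to_ix), (2, 10))

    loss = model.neg_log_likelihood(sentences, tags)  # training objective
    score, tag_seq = model(sentences)                 # Viterbi decoding
    print(loss.item(), tag_seq.shape)                 # scalar loss, (2, 10)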