EncoderDecoder.py
from typing import Optional

import paddle
import paddle.nn as nn
from paddle import Tensor

from Decoder import Decoder
from Embedding import TransformerEmbedding, PositionalEncoding
from Encoder import Encoder


class EncoderDecoder(nn.Layer):
    def __init__(self, vocab_size: int, d_model: int = 512):
        super(EncoderDecoder, self).__init__()
        self.layers_nums = 3
        # Token embedding followed by positional encoding.
        self.embedding = nn.Sequential(
            TransformerEmbedding(vocab_size),
            PositionalEncoding()
        )
        self.encoder = Encoder(self.layers_nums)
        self.decoder = Decoder(self.layers_nums)
        # Project decoder output back to vocabulary logits.
        self.linear = nn.Linear(d_model, vocab_size)
        self.soft_max = nn.Softmax()
        # Positions labelled -1 are ignored when computing the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, x, label, true_label: Optional[Tensor] = None, src_mask=None, tgt_mask=None):
        # Embed the source and (shifted) target sequences.
        input_embedding = self.embedding(x)
        label_embedding = self.embedding(label)
        # Encode the source, then decode conditioned on the encoder output.
        encoder_output = self.encoder(input_embedding, src_mask)
        decoder_output = self.decoder(label_embedding, encoder_output, src_mask, tgt_mask)
        logits = self.linear(decoder_output)

        res_dict = {}
        if true_label is not None:
            # Flatten to (batch * seq_len, vocab_size) for the cross-entropy loss.
            loss = self.loss_fct(logits.reshape((-1, logits.shape[-1])),
                                 true_label.reshape((-1,)))
            res_dict['loss'] = loss
        result = self.soft_max(logits)
        max_index = paddle.argmax(result, axis=-1)
        res_dict['logits'] = result
        res_dict['index'] = max_index
        return res_dict
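
A minimal usage sketch follows, assuming the Encoder, Decoder, TransformerEmbedding and PositionalEncoding layers accept the arguments used above and that -1 marks ignored positions in true_label; the batch size, sequence length and vocabulary size are illustrative only and not part of the file.

# Hypothetical usage example (assumed shapes and vocabulary size).
import paddle

vocab_size = 1000
model = EncoderDecoder(vocab_size)

# Dummy token-id batches of shape (batch_size, seq_len).
src = paddle.randint(0, vocab_size, shape=[2, 10])
tgt_in = paddle.randint(0, vocab_size, shape=[2, 10])    # decoder input (shifted right)
tgt_out = paddle.randint(0, vocab_size, shape=[2, 10])   # labels; -1 would mark ignored positions

out = model(src, tgt_in, true_label=tgt_out)
print(out['loss'])           # scalar training loss
print(out['index'].shape)    # predicted token ids, (2, 10)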