From 3b8992fe274f31df61464fc127cde9ffb34896e7 Mon Sep 17 00:00:00 2001 From: Redwan Karim Sony Date: Fri, 9 Jul 2021 22:56:46 +0600 Subject: [PATCH] updated Encoder and Decoder signature. Changed to the position of the variable device in Encoder to make sure it is the same as the Decoder so that it looks okay and does not create confusion. Thank you. --- .../transformer_from_scratch.py | 86 +++---------------- 1 file changed, 14 insertions(+), 72 deletions(-) diff --git a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py index 80396df6..9ca26442 100644 --- a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py +++ b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py @@ -103,17 +103,7 @@ def forward(self, value, key, query, mask): class Encoder(nn.Module): - def __init__( - self, - src_vocab_size, - embed_size, - num_layers, - heads, - device, - forward_expansion, - dropout, - max_length, - ): + def __init__(self,src_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length,): super(Encoder, self).__init__() self.embed_size = embed_size @@ -122,25 +112,15 @@ def __init__( self.position_embedding = nn.Embedding(max_length, embed_size) self.layers = nn.ModuleList( - [ - TransformerBlock( - embed_size, - heads, - dropout=dropout, - forward_expansion=forward_expansion, - ) - for _ in range(num_layers) - ] - ) + [TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion,) + for _ in range(num_layers)]) self.dropout = nn.Dropout(dropout) def forward(self, x, mask): N, seq_length = x.shape positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device) - out = self.dropout( - (self.word_embedding(x) + self.position_embedding(positions)) - ) + out = self.dropout((self.word_embedding(x) + self.position_embedding(positions))) # In the Encoder the query, key, value are all the same, it's in the # decoder this will change. This might look a bit odd in this case. @@ -168,28 +148,15 @@ def forward(self, x, value, key, src_mask, trg_mask): class Decoder(nn.Module): - def __init__( - self, - trg_vocab_size, - embed_size, - num_layers, - heads, - forward_expansion, - dropout, - device, - max_length, - ): + def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion,dropout, device, max_length,): super(Decoder, self).__init__() self.device = device self.word_embedding = nn.Embedding(trg_vocab_size, embed_size) self.position_embedding = nn.Embedding(max_length, embed_size) self.layers = nn.ModuleList( - [ - DecoderBlock(embed_size, heads, forward_expansion, dropout, device) - for _ in range(num_layers) - ] - ) + [DecoderBlock(embed_size, heads, forward_expansion, dropout, device) + for _ in range(num_layers)]) self.fc_out = nn.Linear(embed_size, trg_vocab_size) self.dropout = nn.Dropout(dropout) @@ -223,28 +190,8 @@ def __init__( ): super(Transformer, self).__init__() - - self.encoder = Encoder( - src_vocab_size, - embed_size, - num_layers, - heads, - device, - forward_expansion, - dropout, - max_length, - ) - - self.decoder = Decoder( - trg_vocab_size, - embed_size, - num_layers, - heads, - forward_expansion, - dropout, - device, - max_length, - ) + self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length,) + self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length,) self.src_pad_idx = src_pad_idx self.trg_pad_idx = trg_pad_idx @@ -257,10 +204,7 @@ def make_src_mask(self, src): def make_trg_mask(self, trg): N, trg_len = trg.shape - trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand( - N, 1, trg_len, trg_len - ) - + trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len) return trg_mask.to(self.device) def forward(self, src, trg): @@ -275,17 +219,15 @@ def forward(self, src, trg): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) - x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to( - device - ) + x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0],[1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device) trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device) src_pad_idx = 0 trg_pad_idx = 0 src_vocab_size = 10 trg_vocab_size = 10 - model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to( - device - ) + + model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device) + out = model(x, trg[:, :-1]) print(out.shape)