diff --git a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
index 80396df6..9ca26442 100644
--- a/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
+++ b/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py
@@ -103,17 +103,7 @@ def forward(self, value, key, query, mask):
 
 
 class Encoder(nn.Module):
-    def __init__(
-        self,
-        src_vocab_size,
-        embed_size,
-        num_layers,
-        heads,
-        device,
-        forward_expansion,
-        dropout,
-        max_length,
-    ):
+    def __init__(self, src_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
         super(Encoder, self).__init__()
         self.embed_size = embed_size
         self.device = device
@@ -122,25 +112,15 @@ def __init__(
         self.position_embedding = nn.Embedding(max_length, embed_size)
 
         self.layers = nn.ModuleList(
-            [
-                TransformerBlock(
-                    embed_size,
-                    heads,
-                    dropout=dropout,
-                    forward_expansion=forward_expansion,
-                )
-                for _ in range(num_layers)
-            ]
-        )
+            [TransformerBlock(embed_size, heads, dropout=dropout, forward_expansion=forward_expansion)
+             for _ in range(num_layers)])
 
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, mask):
         N, seq_length = x.shape
         positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
-        out = self.dropout(
-            (self.word_embedding(x) + self.position_embedding(positions))
-        )
+        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))
 
         # In the Encoder the query, key, value are all the same, it's in the
         # decoder this will change. This might look a bit odd in this case.
@@ -168,28 +148,15 @@ def forward(self, x, value, key, src_mask, trg_mask):
 
 
 class Decoder(nn.Module):
-    def __init__(
-        self,
-        trg_vocab_size,
-        embed_size,
-        num_layers,
-        heads,
-        forward_expansion,
-        dropout,
-        device,
-        max_length,
-    ):
+    def __init__(self, trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length):
         super(Decoder, self).__init__()
         self.device = device
         self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
         self.position_embedding = nn.Embedding(max_length, embed_size)
 
         self.layers = nn.ModuleList(
-            [
-                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
-                for _ in range(num_layers)
-            ]
-        )
+            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
+             for _ in range(num_layers)])
         self.fc_out = nn.Linear(embed_size, trg_vocab_size)
         self.dropout = nn.Dropout(dropout)
 
@@ -223,28 +190,8 @@ def __init__(
         max_length=100,
     ):
         super(Transformer, self).__init__()
-
-        self.encoder = Encoder(
-            src_vocab_size,
-            embed_size,
-            num_layers,
-            heads,
-            device,
-            forward_expansion,
-            dropout,
-            max_length,
-        )
-
-        self.decoder = Decoder(
-            trg_vocab_size,
-            embed_size,
-            num_layers,
-            heads,
-            forward_expansion,
-            dropout,
-            device,
-            max_length,
-        )
+        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)
+        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, device, max_length)
 
         self.src_pad_idx = src_pad_idx
         self.trg_pad_idx = trg_pad_idx
@@ -257,10 +204,7 @@ def make_src_mask(self, src):
 
     def make_trg_mask(self, trg):
         N, trg_len = trg.shape
-        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
-            N, 1, trg_len, trg_len
-        )
-
+        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)
         return trg_mask.to(self.device)
 
     def forward(self, src, trg):
@@ -275,17 +219,15 @@ def forward(self, src, trg):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(device)
 
-    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(
-        device
-    )
+    x = torch.tensor([[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]).to(device)
     trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]).to(device)
 
     src_pad_idx = 0
     trg_pad_idx = 0
     src_vocab_size = 10
    trg_vocab_size = 10
-    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(
-        device
-    )
+
+    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device).to(device)
+
    out = model(x, trg[:, :-1])
     print(out.shape)
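
A quick standalone sanity check of the collapsed one-liner in make_trg_mask (not part of the patch; it reuses the trg tensor from the __main__ block above, and the expected output is illustrative):

    import torch

    # Mirrors make_trg_mask: build an (N, 1, trg_len, trg_len) causal mask
    trg = torch.tensor([[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]])
    N, trg_len = trg.shape
    trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len)

    print(trg_mask.shape)  # torch.Size([2, 1, 8, 8])
    print(trg_mask[0, 0])  # row i is 1 in columns 0..i: position i attends only to earlier positions

Since the refactor only removes line breaks and reorders keyword-style arguments consistently at the definition and call sites, this mask (and the out shape printed in __main__) should be unchanged before and after the patch.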