
Commit

allow for setting embedding dimension to be different than model dimension
lucidrains committed Mar 2, 2021
1 parent 6b93c21 commit c40815f
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.8.3',
+  version = '0.8.4',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
10 changes: 8 additions & 2 deletions x_transformers/x_transformers.py
@@ -607,6 +607,7 @@ def __init__(
         num_tokens,
         max_seq_len,
         attn_layers,
+        emb_dim = None,
         max_mem_len = 0.,
         emb_dropout = 0.,
         num_memory_tokens = None,
@@ -617,13 +618,16 @@ def __init__(
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'

         dim = attn_layers.dim
+        emb_dim = default(emb_dim, dim)
+
         self.max_seq_len = max_seq_len
         self.max_mem_len = max_mem_len

-        self.token_emb = nn.Embedding(num_tokens, dim)
-        self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
         self.emb_dropout = nn.Dropout(emb_dropout)

+        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
         self.attn_layers = attn_layers
         self.norm = nn.LayerNorm(dim)

@@ -660,6 +664,8 @@ def forward(
         x += self.pos_emb(x)
         x = self.emb_dropout(x)

+        x = self.project_emb(x)
+
         if num_mem > 0:
             mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
             x = torch.cat((mem, x), dim = 1)
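In short, the commit adds an optional emb_dim argument: token and positional embeddings are created at emb_dim, and when emb_dim differs from the attention layers' dim, a project_emb linear layer maps the embeddings up to the model dimension in forward. Below is a minimal usage sketch (not part of the commit), assuming the class modified above is x_transformers.TransformerWrapper and that Decoder is the usual attention stack from the same package:

# Hedged usage sketch: assumes the modified class is TransformerWrapper
# and that Decoder provides the attn_layers, as in x-transformers.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    emb_dim = 128,            # new: embedding dimension, smaller than the model dimension
    attn_layers = Decoder(
        dim = 512,            # model dimension used by the attention layers
        depth = 6,
        heads = 8
    )
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)        # embeddings are projected 128 -> 512 by project_emb before the attention layers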
