make sure xl and nar can handle mixture of softmax
lucidrains committed Sep 25, 2024
1 parent c0b0267 commit 2fb096e
Showing 3 changed files with 10 additions and 6 deletions.
setup.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '1.37.3',
+ version = '1.37.4',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
x_transformers/nonautoregressive_wrapper.py: 4 changes (3 additions, 1 deletion)
@@ -309,9 +309,11 @@ def forward(
with context():
logits = self.net(masked, **kwargs)

+ loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss

# cross entropy loss

- loss = F.cross_entropy(
+ loss = loss_fn(
logits[mask],
orig_seq[mask]
)
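The change above makes the non-autoregressive wrapper pick its loss from the wrapped network: F.cross_entropy when the net emits raw logits, F.nll_loss when it emits log-probabilities, which is presumably what the net.is_log_prob flag signals for a mixture-of-softmaxes head. The switch matters because F.cross_entropy applies log_softmax internally, so feeding it already-normalized log-probabilities would normalize twice. A minimal sketch with toy tensors (not part of the commit) showing the two losses agree when each receives its expected input:

    import torch
    import torch.nn.functional as F

    # toy batch: 4 positions over a 10-token vocabulary
    logits = torch.randn(4, 10)
    targets = torch.randint(0, 10, (4,))

    # a mixture-of-softmaxes style head returns log-probabilities, not raw logits
    log_probs = F.log_softmax(logits, dim = -1)

    # F.cross_entropy applies log_softmax internally, so it expects raw logits;
    # F.nll_loss expects log-probabilities that are already normalized
    loss_from_logits = F.cross_entropy(logits, targets)
    loss_from_log_probs = F.nll_loss(log_probs, targets)

    assert torch.allclose(loss_from_logits, loss_from_log_probs)

    # mirroring the wrapper's selection, assuming the wrapped network exposes
    # an `is_log_prob` flag as in the diff above
    is_log_prob = True
    loss_fn = F.cross_entropy if not is_log_prob else F.nll_loss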
x_transformers/xl_autoregressive_wrapper.py: 10 changes (6 additions, 4 deletions)
@@ -40,7 +40,7 @@ def generate(
eos_token = None,
temperature = 1.,
filter_logits_fn = top_k,
- filter_thres = 0.9,
+ filter_kwargs: dict = dict(),
mems = None,
**kwargs
):
@@ -88,7 +88,7 @@ def generate(
mems = cache.mems

logits = logits[:, -1]
- filtered_logits = filter_logits_fn(logits, thres = filter_thres)
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)
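Replacing the fixed filter_thres with a filter_kwargs dict lets generate forward arbitrary keyword arguments to whatever filtering function is supplied. A small sketch of that call pattern, where top_k_filter is an illustrative stand-in written here, not the library's own top_k, whose keyword names may differ:

    import torch
    import torch.nn.functional as F

    def top_k_filter(logits, k = 50):
        # illustrative filter: keep the k largest logits per row, mask the rest to -inf
        values, _ = torch.topk(logits, k, dim = -1)
        cutoff = values[..., -1, None]
        return logits.masked_fill(logits < cutoff, float('-inf'))

    logits = torch.randn(2, 1000)   # last-step logits for a batch of 2

    # generate(...) now forwards filter_kwargs straight into the filter, e.g.
    # generate(..., filter_logits_fn = top_k_filter, filter_kwargs = dict(k = 50))
    filter_kwargs = dict(k = 50)
    filtered_logits = top_k_filter(logits, **filter_kwargs)

    probs = F.softmax(filtered_logits / 1.0, dim = -1)   # temperature = 1.
    sample = torch.multinomial(probs, 1)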
@@ -131,7 +131,9 @@ def forward(

split_x = x.split(max_seq_len, dim = -1)
split_labels = labels.split(max_seq_len, dim = -1)
- loss_weights = tuple(map(lambda t: t.shape[-1] / seq_len, split_x))
+ loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)

+ loss_fn = F.cross_entropy if not self.net.is_log_prob else F.nll_loss

# go through each chunk and derive weighted losses

@@ -146,7 +148,7 @@ def forward(
**kwargs
)

- loss = F.cross_entropy(
+ loss = loss_fn(
rearrange(logits, 'b n c -> b c n'),
chunk_labels,
ignore_index = ignore_index
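Taken together, the xl wrapper's forward pass splits the input into max_seq_len chunks, computes a loss per chunk (cross-entropy for logits, NLL for log-probability outputs, per the loss_fn switch above), and weights each chunk by its share of the full sequence length. A condensed sketch of that weighted accumulation, assuming a hypothetical net(chunk) that returns (batch, seq, vocab) logits and omitting the recurrence memories and the is_log_prob switch that the real wrapper threads through:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def chunked_loss(net, x, labels, max_seq_len, ignore_index = -100):
        seq_len = x.shape[-1]
        split_x = x.split(max_seq_len, dim = -1)
        split_labels = labels.split(max_seq_len, dim = -1)

        # each chunk contributes in proportion to its share of the sequence
        loss_weights = tuple((t.shape[-1] / seq_len) for t in split_x)

        total_loss = 0.
        for chunk, chunk_labels, weight in zip(split_x, split_labels, loss_weights):
            logits = net(chunk)
            loss = F.cross_entropy(
                rearrange(logits, 'b n c -> b c n'),
                chunk_labels,
                ignore_index = ignore_index
            )
            total_loss = total_loss + loss * weight

        return total_loss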
