Skip to content

Commit

Permalink
consistent shape between logits and generated_ids
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Oct 4, 2024
1 parent 4aa4ebd commit 8ce2970
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions exllamav2/generator/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,12 @@ def sample(
# Apply logits processor

if settings.logits_processor:
generated_ids = sequence_ids[:, input_ids.shape[1]:]
logits = settings.logits_processor(generated_ids, logits)
generated_ids = sequence_ids[:, input_ids.shape[-1]:]
# normalize to 2d
logits_2d = logits.view(-1, logits.shape[-1])
generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1])
# process logits and convert back to original logits shape
logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape)

# Prepare filter

Expand Down

0 comments on commit 8ce2970

Please sign in to comment.