diff --git a/exllamav2/generator/sampler.py b/exllamav2/generator/sampler.py
index a6bc113e..0aee2a5b 100644
--- a/exllamav2/generator/sampler.py
+++ b/exllamav2/generator/sampler.py
@@ -364,8 +364,12 @@ def sample(
         # Apply logits processor
 
         if settings.logits_processor:
-            generated_ids = sequence_ids[:, input_ids.shape[1]:]
-            logits = settings.logits_processor(generated_ids, logits)
+            generated_ids = sequence_ids[:, input_ids.shape[-1]:]
+            # Flatten to 2D for the processor: logits -> (-1, last_dim), ids -> one row per logits row
+            logits_2d = logits.view(-1, logits.shape[-1])
+            generated_ids_2d = generated_ids.view(logits_2d.shape[0], generated_ids.shape[-1])
+            # Run the processor on the 2D views, then restore the original logits shape
+            logits = settings.logits_processor(generated_ids_2d, logits_2d).view(logits.shape)
 
         # Prepare filter
 