Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ExLlamaV2Sampler.Settings.logits_processor #634

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

lapp0
Copy link

@lapp0 lapp0 commented Sep 23, 2024

Overview / Motivation

Implements ExLlamaV2Sampler.Settings.logits_processor which allows us to take advantage of third party libraries logits processors, such as Outlines which implements JSON Schema, regex, and Lark structured generation logits processors

Changes

  • Introduce ExLlamaV2Sampler.Settings.logits_processor which allows for logits filtering and augmentation with torch
    • Make necessary changes to generation-related code to ensure logits processor is applied
  • Introduce tests/test_logits_processors.py which is the same as tests/test.py but using a logits processor for all sampling
  • Introduce examples/json_schema_outlines.py

Performance

  • No change in performance for tests.py between this branch and master
  • Slower when optional logits_processor argument is enabled
    • This is because we're sampling twice, once with sample_basic and once with torch.
    • I've experimented with applying logits processors before moving logits to cpu. It improves normal performance to 145tokens/sec
(Tokens / Second) normal streaming (prompt) streaming (response) batched
master -> tests.py 157.71 3491.57 179.95 10.14
this branch -> tests.py 158.39 3559.52 178.04 10.15
this branch -> test_logits_processor.py 122.66 3576.11 134.56 9.97

master -> tests.py

Generating, normal
...
Response generated in 1.51 seconds, 150 tokens, 99.38 tokens/second

Generating, streaming
...
Prompt processed in 0.00 seconds, 15 tokens, 3693.04 tokens/second
Response generated in 1.10 seconds, 150 tokens, 136.67 tokens/second

Generating, batched
...
Response generated in 17.87 seconds, 40 tokens, throughput 8.95 tokens/second

(Note: Generating, batched multi cache fails in master, not due to this PR)

sampler-logits-processor -> tests.py

Generating, normal
...
Response generated in 1.51 seconds, 150 tokens, 99.46 tokens/second

Generating, streaming
...
Prompt processed in 0.00 seconds, 15 tokens, 3534.72 tokens/second                                                            
Response generated in 1.12 seconds, 150 tokens, 133.63 tokens/second                                                          

Generating, batched
...
Response generated in 18.26 seconds, 40 tokens, throughput 8.76 tokens/second

sampler-logits-processor -> test_logits_processor.py

Generating, normal
...
Response generated in 1.22 seconds, 150 tokens, 122.66 tokens/second

Generating, streaming
...
Prompt processed in 0.00 seconds, 15 tokens, 3576.11 tokens/second
Response generated in 1.11 seconds, 150 tokens, 134.56 tokens/second

Generating, batched
...
Response generated in 16.04 seconds, 40 tokens, throughput 9.97 tokens/second

Tests

All tests pass except for tests/test.py / Generating, batched multi cache which also fails in master

Generating, batched multi cache

Traceback (most recent call last):
  File "/root/exllamav2/tests/test.py", line 249, in <module>
    tests(q_model_directory, False, False)
  File "/root/exllamav2/tests/test.py", line 241, in tests
    if model.is_quant(): test_multicache(40)
  File "/root/exllamav2/tests/test.py", line 211, in test_multicache
    logits = model.forward(inputs, caches, input_mask = None).float().cpu()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/exllamav2/exllamav2/model.py", line 809, in forward
    result = self.forward_chunk(input_ids = input_ids,
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/root/exllamav2/exllamav2/model.py", line 922, in forward_chunk
    past_len = cache.current_seq_len
AttributeError: 'list' object has no attribute 'current_seq_len'

@turboderp
Copy link
Owner

This is interesting, and I'll be giving it a closer look later today. I'm a little skeptical, though, for a couple of reasons.

Logit processors tend to do a lot of extraneous work. Many operations and temporary allocations that could be a single iteration over a block of memory in the CPU's L2 cache (sometimes fitting in L1, even), or even literally one line of C++ code in some cases, turn into multiple kernel launches, each of which has to process the entire logit array after all but a few dozen options have been masked out. And you end up performing multiple softmax operations too if you want to combine samplers, since each processor has to output logits for the next processor in the stack.

Batched sampling would be a clear advantage in itself, except ExLlama doesn't require all sequences in a batch to use any of the same settings, so every processor would have to take batched parameters as well to take advantage of this. Not sure what's standard in that regard.

CPUs aren't that slow, either. You have AVX2 to help with anything that requires any real arithmetic (AVX512 is an option, too, blame Intel for screwing that one up for so many users), and you can split batches over multiple cores easily. I could also see issues arising from individual threads competing for the CUDA stream, unless logit processors were used exclusively and/or without multithreaded sampling enabled.

As for the Outlines example, currently with a library like Formatron, grammar constraints can be evaluated entirely in the background adding essentially zero overhead by using the dedicated filter interface. LMFE is written in Python which blocks multithreading, but it can still run while the CPU is waiting for the GPU to complete the forward pass. The straightforward way to use a logit processor as a grammar constraint doesn't really allow for concurrency of any kind. (I haven't checked, but I also doubt it uses pinned memory for the allowed token mask (?), forcing a sync point that would reduce any benefit from running the other processors on the GPU.)

But the main concern is that performance is going to suffer. Samplers in general are kind of irksome and (I feel) often ill-conceived, and this feels like opening the floodgates to a whole host of new issues and complaints.

I'll need to give it some careful consideration and run some tests, I suppose.

@lapp0
Copy link
Author

lapp0 commented Sep 25, 2024

Thanks for your thoughtful reply!

Your concerns about performance are valid, but for structured generation filtering, ExLlamaV2 lags behind both vLLM and Transformers. Recent benchmarks show that ExLlamaV2 incurs 2-15x the overhead compared to vLLM/Transformers. The key difference is that vLLM/Transformers support logits processors. In our own tests with Outlines, we saw a 50x performance boost by switching from list-based filtering to using a tensor of legal tokens within our logits processors.

Also, I’d like to reaffirm that with logits_processors disabled, there’s no performance difference between this branch and master.

Batched sampling would be a clear advantage in itself, except ExLlama doesn't require all sequences in a batch to use any of the same settings, so every processor would have to take batched parameters as well to take advantage of this. Not sure what's standard in that regard.

Based on this, and after some profiling, I agree that your current sampler implementation shouldn't be replaced with logits processors. The core benefit of this PR would be to take advantage of high-performance structured generation logits processors and reduce ExLlamaV2 overhead for that specific task.

Please let me know if I'm missing something or if you have any other questions.

@turboderp
Copy link
Owner

Your concerns about performance are valid, but for structured generation filtering, ExLlamaV2 lags behind both vLLM and Transformers

This may be the case for Outlines, idk. But with Formatron the overhead is negligible, often zero depending on model and batch size. It can even be net negative in some cases since sampling can be skipped when it's constrained to a single token.

The way the pipeline works, the constraint is evaluated while the forward pass is still completing on the GPU and the CPU is idle/busywaiting anyway. For grammar libraries that do the bulk of their work in C++ or Rust with the GIL released, it starts at the same time as the forward pass and runs completely in the background on other CPU cores. This means the final overhead is almost entirely from:

  • constructing a C++ vector from the valid tokens list (pybind)
  • iterating over the vector to mask out logits (C++)

There are several places this could be improved to reduce the overhead even further. But mostly it comes down to reducing the amount of time spent in the Python/Rust/C++ interop layers.

If you pass a Python list to a C++ function, whether it's the sampling logic in exllamav2_ext or an indexing operation in libtorch, it has to be unboxed one element at a time, and this is slow. A tensor reduces to a single pointer so it's thousands of times faster to pass as an argument. This really has nothing to do with CUDA, though, and it would be trivial to pass a mask tensor to ExLlama's sampler function instead of a list (provided the grammar library outputs such a tensor) eliminating most of the remaining overhead. For Formatron specifically, the Rust component internally produces a fixedbitset (i.e. a bit mask over the logits) before converting it to a list, and that would be even more efficient if it could be passed directly.

I'm not sure what the current ExLlama integrations for Outlines look like, though. But I do plan to revisit the grammar stuff soon, and see if there's a way to integrate it into the current filters pipeline.

@lapp0
Copy link
Author

lapp0 commented Sep 25, 2024

This may be the case for Outlines, idk. But with Formatron the overhead is negligible

These benchmarks are from the Formatron repo, they indicate that overhead with their vLLM integration (FormatronLogitsProcessor), there is overhead of 0.0 to 0.23 ms / token while their ExLlamaV2 integration has overhead of 0.17 to 1.46 ms / token. I might be missing something though, I haven't dug too deeply into Formatrons internals.

It can even be Dan-wanna-M/formatron#14 (comment) in some cases since sampling can be skipped when it's constrained to a single token.

Nice to see you have fast-forward implemented! I'll look further into this later since we'll need to consider how our implementations interface might best be suited for downstream consumption :)

I'm not sure what the current ExLlama integrations for Outlines look like, though. But I do plan to revisit the grammar stuff soon, and see if there's a way to integrate it into the current filters pipeline.

Currently we have a one logits processor per generation type (regex, grammars, json schema, etc). Each logits processor works with vLLM, transformers, mlxlm, llama.cpp, and hopefully ExLlamaV2 soon :). There is no distinct logits processor for any of these engines, their implementation is shared.

We've tested the outlines integration with this PR. Users would simply need to run

from outlines.processors import JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor
import outlines
from outlines.models.exllamav2 import patch_tokenizer as as_outlines_tokenizer

...
settings = ExLlamaV2Sampler.Settings()
settings.logits_processor = CFGLogitsProcessor(
    <your lark grammar>,
    as_outlines_tokenizer(generator.tokenizer)
)

I'll let you take some time to review this further. Please let me know if you have any questions or requested changes to help ensure this change conforms to your vision for the project!

@lapp0 lapp0 force-pushed the sampler-logits-processor branch 3 times, most recently from 212d4cd to 8ce2970 Compare October 4, 2024 21:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants