[Performance] KV cache reuse is slower when batch size > 1 #2631

Open
ReginaZh opened this issue Dec 26, 2024 · 0 comments
Labels: Investigating, KV-Cache Management, triaged (issue has been triaged by maintainers)

Comments


ReginaZh commented Dec 26, 2024

I ran some experiments with KV cache reuse. When batch size = 1, engine inference latency decreases as the length of the common prefix increases. However, when the batch size is greater than 1, no matter how long the common prefix is, the latency with KV cache reuse enabled is always higher than with it disabled.

Here are the scripts to reproduce my results.

Build the engine:

cd examples/qwen
git clone https://huggingface.co/Qwen/Qwen2.5-0.5B
python ./convert_checkpoint.py --model_dir Qwen2.5-0.5B --output_dir ./tmp/sq0.5 --dtype float16 --smoothquant 0.5 --per_token --per_channel
trtllm-build --checkpoint_dir ./tmp/sq0.5 --output_dir ./trt_engines --gemm_plugin float16 --gpt_attention_plugin float16 --max_input_len 384 --max_seq_len 385 --max_batch_size 7 --gather_generation_logits --use_paged_context_fmha enable

Engine inference:

import torch
from tensorrt_llm.runtime import ModelRunnerCpp

# Build the runner; kv_cache_enable_block_reuse toggles KV cache block reuse
# (set it to False for the "disable" measurements).
runner_kwargs = dict(
    engine_dir=engine_dir,
    max_batch_size=batch_size,
    max_input_len=384,
    max_output_len=1,
    kv_cache_enable_block_reuse=True,
)
runner = ModelRunnerCpp.from_dir(**runner_kwargs)

# Each batch is a list of token-ID lists; generate a single new token per request.
for batch in batches:
    batch_input_ids = [torch.IntTensor(inp) for inp in batch]
    outputs = runner.generate(
        batch_input_ids=batch_input_ids,
        max_new_tokens=1,
        return_dict=True)
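
For context, here is a minimal sketch of one way the batches with a shared prefix can be constructed (every request in a batch shares its first prefix_len tokens). The prefix/suffix lengths and the random token IDs below are placeholders for illustration, not the exact prompts I used:

import random

input_len = 384      # total prompt length (matches --max_input_len)
prefix_len = 128     # shared prefix length (128 or 256 in the tables below)
batch_size = 7
num_batches = 10     # arbitrary number of test batches

# Arbitrary token IDs well inside the vocabulary; the real prompts differ.
common_prefix = [random.randrange(100, 10000) for _ in range(prefix_len)]

batches = []
for _ in range(num_batches):
    batch = []
    for _ in range(batch_size):
        suffix = [random.randrange(100, 10000) for _ in range(input_len - prefix_len)]
        batch.append(common_prefix + suffix)
    batches.append(batch)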

Results on A100 (max input length = 384, max output length = 1, common prefix = 128 tokens):

Setting                          Latency
bs=1, KV cache reuse enabled     5.03 ms
bs=1, KV cache reuse disabled    5.13 ms
bs=7, KV cache reuse enabled     33.16 ms
bs=7, KV cache reuse disabled    14.72 ms

Results on A100 (max input length = 384, max output length = 1, common prefix = 256 tokens):

Setting                          Latency
bs=1, KV cache reuse enabled     4.56 ms
bs=1, KV cache reuse disabled    5.13 ms
bs=7, KV cache reuse enabled     30.29 ms
bs=7, KV cache reuse disabled    14.72 ms
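
For reference, a sketch of how per-call latency can be measured is below; the helper name, warm-up count, and averaging are illustrative choices, not necessarily the exact measurement script used for the tables above:

import time

import torch

def measure_latency_ms(runner, batches, warmup=2):
    # Warm-up calls so one-time allocations are excluded from the measurement.
    for batch in batches[:warmup]:
        ids = [torch.IntTensor(inp) for inp in batch]
        runner.generate(batch_input_ids=ids, max_new_tokens=1, return_dict=True)

    # Time each remaining generate() call and report the average in milliseconds.
    latencies = []
    for batch in batches[warmup:]:
        ids = [torch.IntTensor(inp) for inp in batch]
        torch.cuda.synchronize()
        start = time.perf_counter()
        runner.generate(batch_input_ids=ids, max_new_tokens=1, return_dict=True)
        torch.cuda.synchronize()
        latencies.append((time.perf_counter() - start) * 1000.0)
    return sum(latencies) / len(latencies)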

In addition, when I set the common prefix length equal to the input length (i.e., the same request is used for inference), the latency remains at 4 ms. Is this expected?

The github-actions bot added the triaged and Investigating labels on Jan 8, 2025.