diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7ab9c2056bc36..7babffc62f431 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -192,7 +192,9 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
-  command: pytest -v -s samplers
+  commands:
+    - pytest -v -s samplers
+    - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
 
 - label: LogitsProcessor Test # 5min
   mirror_hardwares: [amd]
diff --git a/Dockerfile b/Dockerfile
index 2a44ac3c24e67..c13cb5c7e7a95 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
 
 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################
 
 
diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 820fb554888f0..719254a398c03 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -8,6 +8,7 @@
 import torch
 from transformers import GenerationConfig, GenerationMixin
 
+import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def mock_sample(probs, *args, **kwargs):
         return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                  for prob in probs], None)
 
-    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+    # top-k and top-p are only calculated when flashinfer kernel is unavailable
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
+         patch("vllm.model_executor.layers.sampler."
+               "flashinfer_top_k_top_p_sampling", None):
         sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
 
     assert sample_probs is not None
@@ -645,6 +649,37 @@ def mock_sample(probs, *args, **kwargs):
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
 
 
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_flashinfer_fallback(seed: int, device: str):
+    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
+        pytest.skip("Flashinfer sampler is disabled")
+
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+
+    def failing_flashinfer_sampling(*_args, **_kwargs):
+        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                sampling_params, device)
+
+    with patch(
+            "vllm.model_executor.layers.sampler."
+            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
+        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                             sampling_params, device)
+
+    assert sampler_output == fallback_sampler_output
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_repetition_penalty_mixed(device: str):
 
diff --git a/vllm/envs.py b/vllm/envs.py
index b0cb56e58d0da..115ead01f537d 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -30,6 +30,7 @@
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
+    VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -256,6 +257,10 @@ def get_default_config_root():
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
 
+    # If set, vllm will use the FlashInfer sampler
+    "VLLM_USE_FLASHINFER_SAMPLER":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 41abdf211e7e7..7344d59e988f0 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -1,5 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import warnings
+from importlib.util import find_spec
 from math import inf
 from typing import Dict, List, Optional, Tuple
 
@@ -11,6 +13,7 @@
 if HAS_TRITON:
     from vllm.model_executor.layers.ops.sample import sample as sample_triton
 
+import vllm.envs as envs
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
@@ -19,6 +22,16 @@
                            PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceOutput)
 
+if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
+    import flashinfer.sampling
+    # yapf: disable
+    from flashinfer.sampling import (
+        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
+
+    # yapf: enable
+else:
+    flashinfer_top_k_top_p_sampling = None
+
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
@@ -123,7 +136,7 @@ def forward(
             logits = logits.to(torch.float)
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
-        if do_top_p_top_k:
+        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
@@ -476,14 +489,7 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
@@ -491,17 +497,57 @@
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
+def _top_k_top_p_multinomial_with_flashinfer(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("FlashInfer rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
+        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@@ -564,18 +610,28 @@
                     sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
-
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
+
+            if flashinfer_top_k_top_p_sampling is not None:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)
 
             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)
 
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
@@ -713,6 +772,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
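Usage note (reviewer sketch, not part of the diff): a minimal way to exercise the new sampling path locally, assuming a CUDA environment with the flashinfer wheel installed (as pinned in the Dockerfile above). The model name and prompt are only illustrative. Because sampler.py checks envs.VLLM_USE_FLASHINFER_SAMPLER at module import time, the variable must be set before vllm is imported.

import os

# Opt in to the FlashInfer top-k/top-p kernel before importing vllm; the
# sampler module reads this flag when it is imported.
os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"

from vllm import LLM, SamplingParams  # noqa: E402

llm = LLM(model="facebook/opt-125m")  # illustrative small model
params = SamplingParams(temperature=1.0, top_k=40, top_p=0.9, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)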
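Side note on the _multinomial change: the removed expand().contiguous().view() pattern and the new repeat_interleave call build the same (batch * num_samples, vocab) tensor, so sampling behavior is unchanged. A standalone sanity check in plain PyTorch (illustrative only):

import torch

probs = torch.rand(3, 5)
num_samples = 2

# Old pattern removed by the diff: add a sample dimension, expand, then flatten.
old = probs[:, None, :].expand(probs.shape[0], num_samples,
                               probs.shape[1]).contiguous().view(
                                   -1, probs.shape[1])
# New pattern: repeat each row num_samples times along dim 0.
new = probs.repeat_interleave(num_samples, dim=0)

assert torch.equal(old, new)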