diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
new file mode 100644
index 0000000000000..f439afa9b7d2b
--- /dev/null
+++ b/tests/kernels/test_attention_selector.py
@@ -0,0 +1,85 @@
+import os
+from unittest.mock import patch
+
+import pytest
+import torch
+
+from vllm.attention.selector import which_attn_to_use
+
+
+@pytest.mark.parametrize(
+    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
+@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+def test_env(name: str, device: str):
+    """Test that the attention selector can be set via environment variable.
+    Note that we do not test FlashAttn because it is the default backend.
+    """
+    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
+    os.environ["VLLM_ATTENTION_BACKEND"] = name
+
+    if device == "cpu":
+        with patch("vllm.attention.selector.is_cpu", return_value=True):
+            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                        torch.float16, 16)
+        assert backend.name == "TORCH_SDPA"
+    elif device == "hip":
+        with patch("vllm.attention.selector.is_hip", return_value=True):
+            backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                        torch.float16, 16)
+        assert backend.name == "ROCM_FLASH"
+    else:
+        backend = which_attn_to_use(8, 16, 8, None, torch.float16,
+                                    torch.float16, 16)
+        assert backend.name == name
+
+    if name_backup is not None:
+        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
+
+
+def test_flash_attn():
+    """Test FlashAttn validation."""
+    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
+    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
+
+    # Unsupported CUDA arch
+    with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
+        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        assert backend.name != "FLASH_ATTN"
+
+    # Unsupported data type
+    backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
+    assert backend.name != "FLASH_ATTN"
+
+    # Unsupported kv cache data type
+    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
+    assert backend.name != "FLASH_ATTN"
+
+    # Unsupported block size
+    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
+    assert backend.name != "FLASH_ATTN"
+
+    # Unsupported sliding window
+    backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
+    assert backend.name != "FLASH_ATTN"
+
+    # flash-attn is not installed
+    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
+        backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+        assert backend.name != "FLASH_ATTN"
+
+    # Unsupported head size
+    backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
+    assert backend.name != "FLASH_ATTN"
+
+    if name_backup is not None:
+        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
+
+
+def test_invalid_env():
+    """Throw an exception if the backend name is invalid."""
+    name_backup = os.environ.get("VLLM_ATTENTION_BACKEND", None)
+    os.environ["VLLM_ATTENTION_BACKEND"] = "INVALID"
+    with pytest.raises(ValueError):
+        which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
+    if name_backup is not None:
+        os.environ["VLLM_ATTENTION_BACKEND"] = name_backup
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 7210fefbd8162..7b7959d257fac 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -218,6 +218,7 @@ def forward(
             )

         if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
             assert prefill_meta.block_tables is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 output = flash_attn_varlen_func(
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 51c25a81b4130..f191461dcd3b7 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -30,24 +30,16 @@ def get_attn_backend(
     kv_cache_dtype: Optional[str],
     block_size: int,
 ) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                 sliding_window, dtype, kv_cache_dtype,
-                                 block_size)
+    """Determine which attention backend to use and only import
+    the selected backend module.
+    """
+    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                sliding_window, dtype, kv_cache_dtype,
+                                block_size)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
-
-        # We check it here not in _which_attn_to_use because we cannot know
-        # the head size until we import FlashAttentionBackend.
-        supported_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
-        if head_size in supported_head_sizes:
-            logger.info("Using FlashAttention-2 backend.")
-            return FlashAttentionBackend
-        logger.info(
-            "Cannot use FlashAttention-2 backend for head size %d. "
-            "Using XFormers backend instead.", head_size)
-        backend = _Backend.XFORMERS
-
+        return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         from vllm.attention.backends.xformers import (  # noqa: F401
@@ -64,14 +56,15 @@ def get_attn_backend(
         return TorchSDPABackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is enforced for the Flashinfer backend.")
+        logger.warning("Eager mode is required for the Flashinfer backend. "
+                       "Please make sure --enforce-eager is set.")
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     else:
         raise ValueError("Invalid attention backend.")


-def _which_attn_to_use(
+def which_attn_to_use(
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -81,54 +74,84 @@ def _which_attn_to_use(
     block_size: int,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
+
+    # Default case.
+    selected_backend = _Backend.FLASH_ATTN
+
+    # Check the environment variable and override if specified.
+    backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+    if backend_by_env_var is not None:
+        backend_members = _Backend.__members__
+        if backend_by_env_var not in backend_members:
+            raise ValueError(
+                f"Invalid attention backend '{backend_by_env_var}'. "
+                f"Available backends: {', '.join(backend_members)} "
+                "(case-sensitive).")
+        selected_backend = _Backend[backend_by_env_var]
+
     if is_cpu():
+        if selected_backend != _Backend.TORCH_SDPA:
+            logger.info("Cannot use %s backend on CPU.", selected_backend)
         return _Backend.TORCH_SDPA

     if is_hip():
         # AMD GPUs.
-        if torch.cuda.get_device_capability()[0] != 9:
-            # not Instinct series GPUs.
-            logger.info("flash_atten is not supported on NAVI GPUs.")
+        selected_backend = (_Backend.ROCM_FLASH if selected_backend
+                            == _Backend.FLASH_ATTN else selected_backend)
+        if selected_backend == _Backend.ROCM_FLASH:
+            if torch.cuda.get_device_capability()[0] != 9:
+                # not Instinct series GPUs.
+                logger.info("flash_attn is not supported on NAVI GPUs.")
+        else:
+            logger.info("%s is not supported on AMD GPUs.", selected_backend)
         return _Backend.ROCM_FLASH

-    # NVIDIA GPUs.
-    if torch.cuda.get_device_capability()[0] < 8:
-        # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
-                    "GPUs.")
-        return _Backend.XFORMERS
-
-    if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
-                    "torch.float16 or torch.bfloat16.")
-        return _Backend.XFORMERS
-
-    if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-        logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
-        return _Backend.XFORMERS
-
-    if block_size % 16 != 0:
-        logger.info("Cannot use FlashAttention-2 backend for block size not "
-                    "divisible by 16.")
-        return _Backend.XFORMERS
-
-    if sliding_window is not None:
-        logger.info(
-            "Cannot use FlashAttention-2 backend due to sliding window.")
-        return _Backend.XFORMERS
-
-    try:
-        import vllm_flash_attn  # noqa: F401
-    except ImportError:
-        logger.info(
-            "Cannot use FlashAttention-2 backend because the vllm_flash_attn "
-            "package is not found. `pip install vllm-flash-attn` for better "
-            "performance.")
-        return _Backend.XFORMERS
-
-    backend_by_env_var = envs.VLLM_ATTENTION_BACKEND
-    if backend_by_env_var is not None:
-        return _Backend[backend_by_env_var]
-
-    # Default case.
-    return _Backend.FLASH_ATTN
+    # FlashAttn in NVIDIA GPUs.
+    if selected_backend == _Backend.FLASH_ATTN:
+        if torch.cuda.get_device_capability()[0] < 8:
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            selected_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            selected_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            selected_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            selected_backend = _Backend.XFORMERS
+        elif sliding_window is not None:
+            logger.info(
+                "Cannot use FlashAttention-2 backend due to sliding window.")
+            selected_backend = _Backend.XFORMERS
+
+    # FlashAttn is valid for the model; check that the package is installed.
+    if selected_backend == _Backend.FLASH_ATTN:
+        try:
+            import vllm_flash_attn  # noqa: F401
+
+            from vllm.attention.backends.flash_attn import (  # noqa: F401
+                FlashAttentionBackend)
+
+            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
+            if head_size not in supported_sizes:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend for head size %d.",
+                    head_size)
+                selected_backend = _Backend.XFORMERS
+        except ImportError:
+            logger.info(
+                "Cannot use FlashAttention-2 backend because the "
+                "vllm_flash_attn package is not found. "
+                "`pip install vllm-flash-attn` for better performance.")
+            selected_backend = _Backend.XFORMERS
+
+    return selected_backend
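A minimal usage sketch (not part of the diff): the new tests back up and restore VLLM_ATTENTION_BACKEND by hand, and unittest.mock.patch.dict is an equivalent pattern that scopes the override and undoes it automatically. This assumes the vllm package from this branch is importable on a CUDA build (the diff's tests rely on which_attn_to_use reading the environment variable at call time); on CPU or ROCm builds the platform checks still take precedence and return TORCH_SDPA or ROCM_FLASH instead.

```python
# Sketch only, not part of the diff: exercising the env-var override that
# test_env covers, using patch.dict so the variable is restored on exit.
import os
from unittest.mock import patch

import torch

from vllm.attention.selector import which_attn_to_use

with patch.dict(os.environ, {"VLLM_ATTENTION_BACKEND": "XFORMERS"}):
    # Argument order matches the tests above:
    # (num_heads, head_size, num_kv_heads, sliding_window, dtype,
    #  kv_cache_dtype, block_size).
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
    assert backend.name == "XFORMERS"  # expected on a CUDA build
```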