From bd620b01fb74d5269ca6fc0fd32f66bfb205a358 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Thu, 20 Jun 2024 23:39:40 -0700
Subject: [PATCH] [Kernel][CPU] Add Quick `gelu` to CPU (#5717)

---
 csrc/cpu/activation.cpp                  | 19 +++++++++++++++++++
 csrc/cpu/torch_bindings.cpp              |  4 ++++
 vllm/_ipex_ops.py                        |  3 +++
 vllm/model_executor/layers/activation.py |  3 +++
 4 files changed, 29 insertions(+)

diff --git a/csrc/cpu/activation.cpp b/csrc/cpu/activation.cpp
index becd2ac42f17a..039b8d5c30d46 100644
--- a/csrc/cpu/activation.cpp
+++ b/csrc/cpu/activation.cpp
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
   return w3 * x * (ones + t);
 }
 
+FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
+  const vec_op::FP32Vec8 zeros(0.0);
+  const vec_op::FP32Vec8 ones(1.0);
+  const vec_op::FP32Vec8 w1(1.702f);
+  return x / (ones + (zeros - w1 * x).exp());
+}
+
 FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
   const vec_op::FP32Vec8 ones(1.0);
   const vec_op::FP32Vec8 w1(M_SQRT1_2);
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
     CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
   });
 }
+
+void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
+  int num_tokens = input.numel() / input.size(-1);
+  int d = input.size(-1);
+
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
+    CPU_KERNEL_GUARD_IN(gelu_quick_impl)
+    activation_kernel<scalar_t, gelu_quick_act, false>(
+        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
+    CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
+  });
+}
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index a2bf0d49adba5..39e8cf3ed3c10 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_fast", torch::kCPU, &gelu_fast);
 
+  // Quick GELU implementation.
+  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
+  ops.impl("gelu_quick", torch::kCPU, &gelu_quick);
+
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
index 1e60e0848673b..99a875c9b3fb7 100644
--- a/vllm/_ipex_ops.py
+++ b/vllm/_ipex_ops.py
@@ -43,6 +43,9 @@ def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
     def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
         out.copy_(torch.nn.functional.gelu(x))
 
+    # TODO add implementation of gelu_quick here
+    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+
     def paged_attention_v1(
         out: torch.Tensor,
         query: torch.Tensor,
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 80cad15b43426..5bfdba67b443d 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -155,6 +155,9 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         ops.gelu_quick(out, x)
         return out
 
+    # TODO implement forward_xpu for QuickGELU
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
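
Note (not part of the patch): a quick way to sanity-check the new CPU binding is to compare it against the reference Quick GELU formula. The snippet below is a sketch; it assumes a CPU build of vLLM compiled with this patch, and it reaches the kernel through vllm._custom_ops, the same wrapper whose ops.gelu_quick call appears in the forward_cuda context above. Shapes and tolerances are arbitrary.

    import torch
    from vllm import _custom_ops as ops  # assumed wrapper module around torch.ops._C

    def quick_gelu_ref(x: torch.Tensor) -> torch.Tensor:
        # Same formula the vectorized gelu_quick_act kernel evaluates:
        # x / (1 + exp(-1.702 * x)) == x * sigmoid(1.702 * x)
        return x * torch.sigmoid(1.702 * x)

    x = torch.randn(16, 1024, dtype=torch.float32)  # [num_tokens, d], CPU tensor
    out = torch.empty_like(x)
    ops.gelu_quick(out, x)  # dispatches to the new torch::kCPU implementation

    torch.testing.assert_close(out, quick_gelu_ref(x), atol=1e-5, rtol=1e-5)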
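
Note (not part of the patch): the diff deliberately leaves two TODO stubs, gelu_quick in vllm/_ipex_ops.py and forward_xpu on QuickGELU in vllm/model_executor/layers/activation.py. The sketch below shows one hypothetical way to fill them in, written by analogy with the neighboring gelu_new and forward_cuda code; it is an illustration, not the implementation the author intends. The ipex_ops import path is an assumption based on how other forward_xpu methods in that file are written, and the x * sigmoid(1.702 * x) form matches the CPU kernel's x / (1 + exp(-1.702 * x)).

    import torch

    # Hypothetical ipex_ops.gelu_quick fallback (not in the patch), mirroring
    # how gelu_new above falls back to plain PyTorch ops:
    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
        # Quick GELU: x * sigmoid(1.702 * x)
        out.copy_(x * torch.sigmoid(1.702 * x))

    # Hypothetical QuickGELU.forward_xpu (not in the patch), mirroring
    # forward_cuda but routing through the IPEX wrapper instead of torch.ops._C:
    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        from vllm._ipex_ops import ipex_ops as ops  # assumed import path
        out = torch.empty_like(x)
        ops.gelu_quick(out, x)
        return out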