[Kernel][CPU] Add Quick gelu to CPU (vllm-project#5717)
ywang96 authored Jun 21, 2024
1 parent d9a252b commit bd620b0
Showing 4 changed files with 29 additions and 0 deletions.
19 changes: 19 additions & 0 deletions csrc/cpu/activation.cpp
@@ -59,6 +59,13 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
return w3 * x * (ones + t);
}

FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(1.702f);
return x / (ones + (zeros - w1 * x).exp());
}

FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
@@ -142,3 +149,15 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
});
}

void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);

VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_quick_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_quick_impl)
activation_kernel<scalar_t, gelu_quick_act, false>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
});
}
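
The new gelu_quick_act evaluates quick GELU as x / (1 + exp(-1.702 * x)), which is the same function as the more common form x * sigmoid(1.702 * x). A small PyTorch check (illustrative only, not part of the commit) confirms the two forms agree:

import torch

x = torch.randn(8)
simd_form = x / (1 + torch.exp(0 - 1.702 * x))   # mirrors (zeros - w1 * x).exp() in the kernel
sigmoid_form = x * torch.sigmoid(1.702 * x)      # common quick-GELU formulation
assert torch.allclose(simd_form, sigmoid_form)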
4 changes: 4 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -58,6 +58,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

// Quick GELU implementation.
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCPU, &gelu_quick);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
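
With the op registered, the kernel becomes reachable from Python through the torch library the extension is built into. A hedged sketch of the call path, assuming the CPU extension is exposed as vLLM's `_C` library (the usual value of TORCH_EXTENSION_NAME) and has been built for CPU:

import torch
# from vllm import _custom_ops as ops   # higher-level wrapper vLLM code typically goes through

x = torch.randn(4, 128)
out = torch.empty_like(x)
torch.ops._C.gelu_quick(out, x)   # writes into `out`, per the (Tensor! out, Tensor input) schema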
3 changes: 3 additions & 0 deletions vllm/_ipex_ops.py
@@ -43,6 +43,9 @@ def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))

# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
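
The TODO above leaves the XPU path without a quick GELU. One possible shape for it, mirroring how gelu_new in this file copies an eagerly computed result into out (a sketch, not something this commit adds):

import torch

def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
    # Quick GELU: x * sigmoid(1.702 * x), matching the CPU kernel above.
    out.copy_(x * torch.sigmoid(1.702 * x))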
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -155,6 +155,9 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_quick(out, x)
return out

# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
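
Similarly, forward_xpu for QuickGELU remains a TODO. A minimal sketch, assuming it falls back to the eager quick-GELU formula until an XPU kernel (such as an ipex_ops.gelu_quick) is available; this is hypothetical and not part of the commit:

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    # Fallback: eager quick GELU until a dedicated XPU kernel exists.
    return x * torch.sigmoid(1.702 * x)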
