Commit 7504696
more efficient m blocking, a couple small fixes
ElizaWszola committed Jul 23, 2024
1 parent d8b455f commit 7504696
Showing 2 changed files with 22 additions and 16 deletions.
15 changes: 6 additions & 9 deletions csrc/moe/marlin_moe_ops.cu
@@ -224,8 +224,7 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows,
-                                    int num_threads) {
+                                    int size_k, int block_rows) {
   int start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
@@ -236,8 +235,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   int row_stride = size_k * sizeof(half) / 16;

   auto permute_row = [&](int row) {
-    int iters = size_k / num_threads;
-    int rest = size_k % num_threads;
+    int iters = size_k / blockDim.x;
+    int rest = size_k % blockDim.x;

     int offset = row * row_stride;

@@ -252,7 +251,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,

       out_half[cur_k] = a_row_half[src_pos];

-      base_k += num_threads;
+      base_k += blockDim.x;
     }

     if (rest) {
@@ -1218,8 +1217,7 @@ __global__ void MarlinMoE(
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows,
-                                    int num_threads) {
+                                    int size_k, int block_rows) {
   // Marlin is not implemented yet for SM < 8.0
   assert(false);
   return;
@@ -1513,8 +1511,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     int topk_rows = replicate_input ? tot_m : tot_m * topk;
     int block_rows = ceildiv(topk_rows, blocks);
     permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows,
-        USER_THREADS);
+        A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
     A_ptr = a_tmp_ptr;
   }

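The kernel change above is one of the commit's "small fixes": the thread count was passed in as an extra num_threads argument (USER_THREADS at the call site), even though every CUDA kernel can read its launch-time block size from the built-in blockDim.x. Using the built-in removes a redundant parameter that could silently disagree with the actual launch configuration. A minimal Numba sketch of the same thread-strided pattern (illustrative only, not the repo's kernel; permute_rows and its arguments are hypothetical):

from numba import cuda

@cuda.jit
def permute_rows(src, perm, dst):
    # One block per row; the block's threads cooperatively stride
    # across the columns.
    row = cuda.blockIdx.x
    k = cuda.threadIdx.x
    while k < src.shape[1]:
        # Gather the permuted column, mirroring
        # out_half[cur_k] = a_row_half[src_pos] in the kernel above.
        dst[row, k] = src[row, perm[k]]
        # blockDim.x comes from the launch configuration itself,
        # so no num_threads argument is needed.
        k += cuda.blockDim.x

A launch such as permute_rows[num_rows, 256](src, perm, dst) then fixes the stride implicitly through the second launch parameter.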
23 changes: 16 additions & 7 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -315,14 +315,15 @@ def get_default_config(
     K: int,
     topk: int,
     dtype: Optional[str],
+    is_marlin: bool,
 ) -> Dict[str, int]:
     config = {
         'BLOCK_SIZE_M': 64,
         'BLOCK_SIZE_N': 64,
         'BLOCK_SIZE_K': 32,
         'GROUP_SIZE_M': 8
     }
-    if M <= E:
+    if M <= E or (is_marlin and M <= 32):
         config = {
             'BLOCK_SIZE_M': 16,
             'BLOCK_SIZE_N': 32,
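This hunk is the "more efficient m blocking" of the commit title: the small 16-row tile is now selected not only when there are fewer tokens than experts (M <= E) but also for Marlin kernels on any small batch (M <= 32), where a 64-row tile would be mostly padding. A condensed sketch of just the tile-size decision (pick_block_size_m is a hypothetical helper; the real function returns a full config dict):

def pick_block_size_m(M: int, E: int, is_marlin: bool) -> int:
    # With few tokens per expert, a 64-row tile is mostly padding;
    # drop to 16 rows when M <= E, or for Marlin whenever M <= 32.
    if M <= E or (is_marlin and M <= 32):
        return 16
    return 64

# E.g. M=16 tokens routed over E=8 experts: the old rule (M <= E only)
# kept the 64-row tile, the new Marlin branch picks 16.
assert pick_block_size_m(16, 8, is_marlin=True) == 16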
@@ -399,8 +400,14 @@ def grouped_topk(
     return topk_weights, topk_ids


-def get_expert_config(w1: torch.Tensor, w2: torch.Tensor, topk: int, M: int,
-                      N: int, E: int, use_fp8: bool):
+def get_expert_config(w1: torch.Tensor,
+                      w2: torch.Tensor,
+                      topk: int,
+                      M: int,
+                      N: int,
+                      E: int,
+                      use_fp8: bool,
+                      is_marlin: bool = False):
     # First try to load optimal config from the file
     configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

@@ -411,7 +418,7 @@ def get_expert_config(w1: torch.Tensor, w2: torch.Tensor, topk: int, M: int,
     else:
         # Else use the default config
         return get_default_config(M, E, N, w1.shape[2], topk,
-                                  "float8" if use_fp8 else None)
+                                  "float8" if use_fp8 else None, is_marlin)


def fused_experts(hidden_states: torch.Tensor,
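get_expert_config first tries tuned configurations loaded from file and only then falls back to get_default_config, which now receives the is_marlin flag. A rough sketch of that flow, reusing the helpers named in the diff; the closest-M selection rule is an assumption, since the if-branch is not shown in this hunk:

def choose_config(M, E, N, K, topk, dtype, is_marlin=False):
    # Tuned entries are assumed to be keyed by benchmarked batch size.
    configs = get_moe_configs(E, K, dtype)
    if configs:
        # Assumed rule: take the entry tuned at the batch size
        # closest to the actual M.
        return configs[min(configs.keys(), key=lambda x: abs(x - M))]
    # Otherwise fall back to the heuristic default, now Marlin-aware.
    return get_default_config(M, E, N, K, topk, dtype, is_marlin)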
@@ -655,7 +662,8 @@ def single_marlin_moe(
     if override_config:
         config = override_config
     else:
-        config = get_expert_config(w, w, topk_ids.shape[1], M, N, E, use_fp8)
+        config = get_expert_config(w, w, topk_ids.shape[1], M, N, E, use_fp8,
+                                   True)

block_size_m = config['BLOCK_SIZE_M']

@@ -743,14 +751,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if override_config:
         config = override_config
     else:
-        config = get_expert_config(w1, w2, topk_ids.shape[1], M, N, E, use_fp8)
+        config = get_expert_config(w1, w2, topk_ids.shape[1], M, N, E, use_fp8,
+                                   True)

     block_size_m = config['BLOCK_SIZE_M']

     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, block_size_m, E)

-    max_workspace_size = (N // 64) * 16
+    max_workspace_size = (max(N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
                             device="cuda",
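The workspace change is the other substantive fix: the Marlin workspace was sized from N alone, which presumably under-allocates whenever the second GEMM's output width K exceeds N; taking max(N, K) bounds both matmuls. A minimal sketch of the new sizing rule (marlin_moe_workspace is a hypothetical helper):

import torch

def marlin_moe_workspace(N: int, K: int) -> torch.Tensor:
    # Cover the wider of the two GEMM output dimensions. Before this
    # commit only N was used, so a layer with K > N could get a
    # too-small buffer.
    max_workspace_size = (max(N, K) // 64) * 16
    return torch.zeros(max_workspace_size, dtype=torch.int,
                       device="cuda", requires_grad=False)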

2 comments on commit 7504696

@github-actions
smaller_is_better

Benchmark suite: VLLM Serving - Dense. Current: 7504696, Previous: 9daca33.
Environment: NVIDIA L4 x 1; vLLM 0.5.1; Python 3.10.12; torch 2.3.0+cu121; benchmark_serving with nr-qps-pair "300,1"; dataset sharegpt; sparsity None.

Metric        Model (max-model-len)                        Current      Previous     Ratio
mean_ttft_ms  meta-llama/Meta-Llama-3-8B-Instruct (4096)   179.70 ms    186.40 ms    0.96
mean_tpot_ms  meta-llama/Meta-Llama-3-8B-Instruct (4096)   84.35 ms     85.41 ms     0.99
mean_ttft_ms  facebook/opt-350m (2048)                     24.48 ms     23.54 ms     1.04
mean_tpot_ms  facebook/opt-350m (2048)                     6.20 ms      6.09 ms      1.02

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
smaller_is_better

Benchmark suite: VLLM Serving - Dense. Current: 7504696, Previous: 9daca33 (no previous values recorded for this suite).
Environment: NVIDIA H100 80GB HBM3 x 1; vLLM 0.5.1; Python 3.10.12; torch 2.3.0+cu121; benchmark_serving with nr-qps-pair "300,1"; dataset sharegpt; sparsity None.

Metric        Model (max-model-len)                        Current
mean_ttft_ms  facebook/opt-350m (2048)                     41.32 ms
mean_tpot_ms  facebook/opt-350m (2048)                     7.67 ms
mean_ttft_ms  meta-llama/Meta-Llama-3-8B-Instruct (4096)   32.60 ms
mean_tpot_ms  meta-llama/Meta-Llama-3-8B-Instruct (4096)   11.78 ms

This comment was automatically generated by workflow using github-action-benchmark.
