diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
index bc974b10e8fb5..69a4ab42da9ac 100644
--- a/csrc/moe/marlin_moe_ops.cu
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -224,8 +224,7 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows,
-                                    int num_threads) {
+                                    int size_k, int block_rows) {
   int start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
@@ -236,8 +235,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
   int row_stride = size_k * sizeof(half) / 16;
 
   auto permute_row = [&](int row) {
-    int iters = size_k / num_threads;
-    int rest = size_k % num_threads;
+    int iters = size_k / blockDim.x;
+    int rest = size_k % blockDim.x;
 
     int offset = row * row_stride;
 
@@ -252,7 +251,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
 
       out_half[cur_k] = a_row_half[src_pos];
 
-      base_k += num_threads;
+      base_k += blockDim.x;
     }
 
     if (rest) {
@@ -1218,8 +1217,7 @@ __global__ void MarlinMoE(
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows,
-                                    int num_threads) {
+                                    int size_k, int block_rows) {
   // Marlin is not implemented yet for SM < 8.0
   assert(false);
   return;
@@ -1513,8 +1511,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C,
     int topk_rows = replicate_input ? tot_m : tot_m * topk;
     int block_rows = ceildiv(topk_rows, blocks);
     permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows,
-        USER_THREADS);
+        A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
     A_ptr = a_tmp_ptr;
   }
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index cf316c0f9afa4..195179df24941 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -315,6 +315,7 @@ def get_default_config(
     K: int,
     topk: int,
     dtype: Optional[str],
+    is_marlin: bool,
 ) -> Dict[str, int]:
     config = {
         'BLOCK_SIZE_M': 64,
@@ -322,7 +323,7 @@ def get_default_config(
         'BLOCK_SIZE_K': 32,
         'GROUP_SIZE_M': 8
     }
-    if M <= E:
+    if M <= E or (is_marlin and M <= 32):
         config = {
             'BLOCK_SIZE_M': 16,
             'BLOCK_SIZE_N': 32,
@@ -399,8 +400,14 @@ def grouped_topk(
     return topk_weights, topk_ids
 
 
-def get_expert_config(w1: torch.Tensor, w2: torch.Tensor, topk: int, M: int,
-                      N: int, E: int, use_fp8: bool):
+def get_expert_config(w1: torch.Tensor,
+                      w2: torch.Tensor,
+                      topk: int,
+                      M: int,
+                      N: int,
+                      E: int,
+                      use_fp8: bool,
+                      is_marlin: bool = False):
     # First try to load optimal config from the file
     configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
 
@@ -411,7 +418,7 @@ def get_expert_config(w1: torch.Tensor, w2: torch.Tensor, topk: int, M: int,
     else:
         # Else use the default config
         return get_default_config(M, E, N, w1.shape[2], topk,
-                                  "float8" if use_fp8 else None)
+                                  "float8" if use_fp8 else None, is_marlin)
 
 
 def fused_experts(hidden_states: torch.Tensor,
@@ -655,7 +662,8 @@ def single_marlin_moe(
     if override_config:
         config = override_config
     else:
-        config = get_expert_config(w, w, topk_ids.shape[1], M, N, E, use_fp8)
+        config = get_expert_config(w, w, topk_ids.shape[1], M, N, E, use_fp8,
+                                   True)
 
     block_size_m = config['BLOCK_SIZE_M']
 
@@ -743,14 +751,15 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     if override_config:
         config = override_config
     else:
-        config = get_expert_config(w1, w2, topk_ids.shape[1], M, N, E, use_fp8)
+        config = get_expert_config(w1, w2, topk_ids.shape[1], M, N, E, use_fp8,
+                                   True)
 
     block_size_m = config['BLOCK_SIZE_M']
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
         topk_ids, block_size_m, E)
 
-    max_workspace_size = (N // 64) * 16
+    max_workspace_size = (max(N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
                             device="cuda",
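
For reference, the snippet below is a minimal, self-contained sketch of the config selection this diff adds to get_default_config; it is not the vLLM helper itself. The 'BLOCK_SIZE_K': 64 and 'GROUP_SIZE_M': 1 entries of the small-M branch are assumed here, since the hunk above is cut off before those lines.

from typing import Dict, Optional


def get_default_config(M: int, E: int, N: int, K: int, topk: int,
                       dtype: Optional[str],
                       is_marlin: bool) -> Dict[str, int]:
    # Large default tile for the fused MoE GEMM.
    config = {
        'BLOCK_SIZE_M': 64,
        'BLOCK_SIZE_N': 64,
        'BLOCK_SIZE_K': 32,
        'GROUP_SIZE_M': 8
    }
    # Small-M tile: taken when there are fewer tokens than experts, or,
    # on the Marlin path, whenever the batch is at most 32 tokens.
    if M <= E or (is_marlin and M <= 32):
        config = {
            'BLOCK_SIZE_M': 16,
            'BLOCK_SIZE_N': 32,
            'BLOCK_SIZE_K': 64,  # assumed, not visible in the hunk above
            'GROUP_SIZE_M': 1    # assumed, not visible in the hunk above
        }
    return config


# With is_marlin=True, 24 tokens over 8 experts now picks the 16-row tile,
# which the old "M <= E" test alone (24 > 8) would not have selected.
print(get_default_config(24, 8, 4096, 4096, 2, None, True)['BLOCK_SIZE_M'])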