From a897c8ee6c0c9fbf9c23f28e61fc09fc5454e919 Mon Sep 17 00:00:00 2001
From: Jiuzheng Wang
Date: Sun, 31 Mar 2024 14:18:04 +0800
Subject: [PATCH] float16 support for CuTe micro kernel

---
 .../micro_kernel/matmul/cutlass/gemm_sm80.h  | 36 ++-----------------
 src/schedule/lower_cutlass_micro_block.cc    | 33 -----------------
 .../test_program_with_micro_kernel.py        |  2 +-
 3 files changed, 3 insertions(+), 68 deletions(-)

diff --git a/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
index e1713aea9..4f8ab80e9 100644
--- a/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
+++ b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
@@ -16,51 +16,19 @@ using namespace cute;
 
 template
 struct DispatchInstruction;
 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
 template <> struct DispatchInstruction {
     using MMA = MMA_Atom;
 };
 template <> struct DispatchInstruction {
     using MMA = MMA_Atom;
 };
-#endif
 
 template
 struct OperandTraits {
-    // Primary template, use padded layout and default copy
     static constexpr int stride = K_inner ? K : N;
-    static constexpr int padded =
-        stride % (256 / Bits) == 0 ? stride + 128 / Bits : stride;
     using Layout = typename std::conditional<
-        K_inner, Layout, Int>, Shape, _1>>,
-        Layout, Int>, Shape<_1, Int>>>::type;
-    using Copy = DefaultCopy;
-};
-template
-struct OperandTraits<16, N, K, true,
-                     typename std::enable_if::type> {
-    using Layout = Layout, Int>, Stride, _1>>;
-    using Copy = DefaultCopy;
-};
-
-template
-struct OperandTraits<16, N, K, false,
-                     typename std::enable_if::type> {
-    using Layout = Layout, Int>, Stride<_1, Int>>;
-    using Copy = DefaultCopy;
-};
-
-template
-struct OperandTraits<64, N, K, true,
-                     typename std::enable_if::type> {
-    using Layout = Layout, Int>, Stride, _1>>;
-    using Copy = DefaultCopy;
-};
-
-template
-struct OperandTraits<64, N, K, false,
-                     typename std::enable_if::type> {
-    using Layout = Layout, Int>, Stride<_1, Int>>;
+        K_inner, Layout, Int>, Shape, _1>>,
+        Layout, Int>, Shape<_1, Int>>>::type;
     using Copy = DefaultCopy;
 };
diff --git a/src/schedule/lower_cutlass_micro_block.cc b/src/schedule/lower_cutlass_micro_block.cc
index 9e4b8f5c2..7bce8d21d 100644
--- a/src/schedule/lower_cutlass_micro_block.cc
+++ b/src/schedule/lower_cutlass_micro_block.cc
@@ -165,7 +165,6 @@ class LowerCutlassMicroBlock : public SymbolTable {
         int nDimsCAll = op->indices_.size();
         ASSERT(nDimsCAll >= 9); // See comments in `lowerCutlassMicroBlock` below
-<<<<<<< HEAD
         switch (DType) {
         case BaseDataType::Float64: {
             auto batchInWarpPartition =
                 makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
             auto mInWarpPartition =
                 makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
             auto nInWarpPartition =
                 makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
             auto mInThreadPartition =
                 makeEQ(op->indices_[nDimsCAll - 3],
                        makeFloorDiv(prop_->laneId_, makeIntConst(4)));
             auto nInThreadPartition =
                 makeEQ(op->indices_[nDimsCAll - 2],
                        makeMod(prop_->laneId_, makeIntConst(4)));
             ret = makeIf(
                 makeLAnd(makeLAnd(batchInWarpPartition,
                                   makeLAnd(mInWarpPartition, nInWarpPartition)),
                          makeLAnd(mInThreadPartition, nInThreadPartition)),
                 ret);
             break;
         }
         }
-=======
-        auto batchInWarpPartition =
-            makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
-        auto mInWarpPartition =
-            makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
-        auto nInWarpPartition =
-            makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
-        auto mInThreadPartition =
-            makeEQ(op->indices_[nDimsCAll - 3],
-                   makeFloorDiv(prop_->laneId_, makeIntConst(4)));
-        auto nInThreadPartition =
-            makeEQ(op->indices_[nDimsCAll - 2],
-                   makeMod(prop_->laneId_, makeIntConst(4)));
-
-        ret = makeIf(
-            makeLAnd(makeLAnd(batchInWarpPartition,
-                              makeLAnd(mInWarpPartition, nInWarpPartition)),
-                     makeLAnd(mInThreadPartition, nInThreadPartition)),
-            ret);
->>>>>>> master
     }
     return ret;
 }
@@ -280,17 +259,6 @@ class LowerCutlassMicroBlock : public SymbolTable {
 
             int nDimsCAll = c->indices_.size();
             ASSERT(nDimsCAll >= 9); // See comments in `lowerCutlassMicroBlock` below
-<<<<<<< HEAD
-=======
-            c->indices_[nDimsCAll - 9] = warpIdBatch;
-            c->indices_[nDimsCAll - 4] = warpIdM; // m warps
-            c->indices_[nDimsCAll - 3] =
-                makeFloorDiv(laneId, makeIntConst(4)); // m threads
-            c->indices_[nDimsCAll - 5] = warpIdN; // n warps
-            c->indices_[nDimsCAll - 2] =
-                makeMod(laneId, makeIntConst(4)); // n threads
->>>>>>> master
-
             switch (DType) {
             case BaseDataType::Float64: {
                 c->indices_[nDimsCAll - 9] = warpIdBatch;
@@ -504,7 +472,6 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId,
         }
     }
 
-    // Lower to CutlassMicroThread
     LowerCutlassMicroBlock lowerCutlassMicroBlock{matMulId, nWarpBatch, nWarpM,
                                                   nWarpN};
 
diff --git a/test/70.program/test_program_with_micro_kernel.py b/test/70.program/test_program_with_micro_kernel.py
index 7faf1a3b1..bfa33d91b 100644
--- a/test/70.program/test_program_with_micro_kernel.py
+++ b/test/70.program/test_program_with_micro_kernel.py
@@ -93,4 +93,4 @@ def matmul(a: ft.Var[(M, K), "float16"], b: ft.Var[(K, N), "float16"]):
     b_arr = ft.array(b_torch)
     y_arr = exe(a_arr, b_arr)
     y_torch = y_arr.torch()
-    assert torch.all(torch.isclose(y_torch, y_std, rtol = 2e-2))
+    assert torch.all(torch.isclose(y_torch, y_std, rtol=2e-2))
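For reference, below is a minimal sketch of how the float16 path is exercised end to end, reconstructed from the tail of the updated test above. The problem sizes, the random torch inputs, and the reference result `y_std` are assumptions for illustration only; `exe` stands for the compiled FreeTensor executable that the earlier, unshown part of the test builds from `matmul`, and `ft` is the FreeTensor module the test imports.

# Sketch only: M, K, N, the inputs, and y_std are assumed; `exe` and `ft`
# come from the surrounding test file, which the hunk above does not show.
import torch

M, K, N = 256, 256, 256                        # assumed problem sizes
a_torch = torch.rand(M, K, dtype=torch.float16, device="cuda")
b_torch = torch.rand(K, N, dtype=torch.float16, device="cuda")
y_std = a_torch @ b_torch                      # reference result

a_arr = ft.array(a_torch)                      # wrap torch tensors for FreeTensor
b_arr = ft.array(b_torch)
y_arr = exe(a_arr, b_arr)                      # run the compiled micro-kernel program
y_torch = y_arr.torch()

# float16 accumulation is less precise, hence the relaxed relative tolerance
assert torch.all(torch.isclose(y_torch, y_std, rtol=2e-2))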