Migrate CUTLASS micro kernel to CuTe (#601)
* Migrate CUTLASS micro kernel to CuTe

* Remove currently unsupported specializations

* Workaround a cmake dependency issue

---------

Co-authored-by: Jiuzheng Wang <[email protected]>
Co-authored-by: Shizhi Tang <[email protected]>
3 people authored Mar 27, 2024
1 parent 50909cc commit b13775e
Showing 5 changed files with 117 additions and 138 deletions.
11 changes: 9 additions & 2 deletions include/schedule/var_reorder.h
@@ -14,11 +14,14 @@ class VarReorder : public SymbolTable<Mutator> {
ID def_;
std::string var_;
std::vector<int> order_;
bool forceReorderInMatMul_;
bool found_ = false;

public:
VarReorder(const ID &def, const std::vector<int> &order)
: def_(def), order_(order) {
VarReorder(const ID &def, const std::vector<int> &order,
bool forceReorderInMatMul)
: def_(def), order_(order),
forceReorderInMatMul_(forceReorderInMatMul) {
std::vector<int> numbers;
numbers.reserve(order.size());
for (int i = 0, n = order.size(); i < n; i++) {
@@ -58,6 +61,10 @@ class VarReorder : public SymbolTable<Mutator> {
Stmt visit(const MatMul &op) override;
};

Stmt varReorderImpl(const Stmt &ast, const ID &def,
const std::vector<int> &order,
bool forceReorderInMatMul = false);

Stmt varReorder(const Stmt &ast, const ID &def, const std::vector<int> &order);

} // namespace freetensor
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ doc = [
[build-system]
requires = [
"py-build-cmake~=0.1.8",
"importlib_metadata", # Workaround https://github.com/scikit-build/cmake-python-distributions/issues/471
# We can't use pybind11-stubgen here. It will break CMake's incremental compilation
"z3-solver",
"setuptools", # Required by z3: https://github.com/Z3Prover/z3/issues/2374
179 changes: 65 additions & 114 deletions runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
@@ -1,131 +1,85 @@
/**
* This file is borrowed from
* https://github.com/nox-410/tvm.tl/blob/tl/src/tl/tl_templates/gemm_sm80.h
* https://github.com/nox-410/tvm.tl/blob/tl/src/tl/tl_templates/cute_gemm.h
* under the Apache License, and modified for use.
*/

#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H
#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H
#pragma once

#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/warp/mma_tensor_op.h>
#include <cutlass/numeric_types.h>

using cutlass::gemm::GemmShape;
using namespace cute;

template <typename A_type, typename B_type, typename C_type>
struct DispatchInstruction;

template <>
struct DispatchInstruction<cutlass::half_t, cutlass::half_t, cutlass::half_t> {
using Shape = GemmShape<16, 8, 16>;
};
template <>
struct DispatchInstruction<cutlass::half_t, cutlass::half_t, float> {
using Shape = GemmShape<16, 8, 16>;
};
template <>
struct DispatchInstruction<cutlass::bfloat16_t, cutlass::bfloat16_t, float> {
using Shape = GemmShape<16, 8, 16>;
};
template <>
struct DispatchInstruction<cutlass::tfloat32_t, cutlass::tfloat32_t, float> {
using Shape = GemmShape<16, 8, 8>;
};
template <> struct DispatchInstruction<double, double, double> {
using Shape = GemmShape<8, 8, 4>;
};
template <> struct DispatchInstruction<int8_t, int8_t, int> {
using Shape = GemmShape<16, 8, 32>;
using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
};
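
For orientation, a sketch (an editorial assumption, not part of this diff): in the CuTe version, each (A, B, C) element-type triple selects an MMA_Atom wrapping an SM80 tensor-core PTX instruction, rather than a bare GemmShape. The double-precision specialization above shows the committed form; a half-precision one, if kept by this commit, would plausibly read:

// Hypothetical sketch; SM80_16x8x16_F16F16F16F16_TN is a real CuTe atom
// (cute/arch/mma_sm80.hpp), but this exact specialization is not confirmed
// by the visible diff.
template <>
struct DispatchInstruction<cutlass::half_t, cutlass::half_t, cutlass::half_t> {
    using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
};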

template <bool transpose> struct DispatchSharedMemoryLayout;

template <> struct DispatchSharedMemoryLayout<true> {
using Layout = cutlass::layout::ColumnMajor;
};
template <> struct DispatchSharedMemoryLayout<false> {
using Layout = cutlass::layout::RowMajor;
template <int Bits, int N, int K, bool K_inner, typename Enable = void>
struct OperandTraits {
static constexpr int stride = K_inner ? K : N;
using Layout = typename std::conditional<
K_inner, Layout<Shape<Int<N>, Int<K>>, Shape<Int<K>, _1>>,
Layout<Shape<Int<N>, Int<K>>, Shape<_1, Int<N>>>>::type;
using Copy = DefaultCopy;
};
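
Aside (editorial): OperandTraits chooses between a K-contiguous and an MN-contiguous shared-memory layout for an (N, K) operand tile. A minimal illustration of what the two choices mean, written with the idiomatic Stride<> alias (the Shape<> appearing in stride position above is equivalent, since Shape and Stride both alias cute::tuple):

// For an (N=64, K=32) tile:
using KMajor  = Layout<Shape<Int<64>, Int<32>>, Stride<Int<32>, _1>>; // K_inner: (n, k) -> n*32 + k
using MNMajor = Layout<Shape<Int<64>, Int<32>>, Stride<_1, Int<64>>>; // !K_inner: (n, k) -> n + k*64
// e.g. KMajor{}(2, 5) == 69 while MNMajor{}(2, 5) == 322.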

template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type, typename B_type, typename C_type>
class GemmTensorOp {
public:
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
cutlass::tfloat32_t, A_type_raw>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
cutlass::tfloat32_t, B_type_raw>::type;
using C_type = C_type_raw;
using InstructionShape =
typename DispatchInstruction<A_type, B_type, C_type>::Shape;
using SMemLayoutA = typename DispatchSharedMemoryLayout<trans_A>::Layout;
using SMemLayoutB = typename DispatchSharedMemoryLayout<trans_B>::Layout;

using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
InstructionShape, 32, A_type, cutlass::layout::RowMajor, B_type,
cutlass::layout::ColumnMajor, C_type, cutlass::layout::RowMajor,
cutlass::arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;

static_assert(Shape::kM % num_warp_m == 0);
static_assert(Shape::kN % num_warp_n == 0);

using MmaWarp = typename cutlass::gemm::warp::MmaTensorOp<
GemmShape<Shape::kM / num_warp_m, Shape::kN / num_warp_n,
InstructionShape::kK>,
A_type, SMemLayoutA, B_type, SMemLayoutB, C_type,
cutlass::layout::RowMajor, Policy, 1,
true /* accumulate in row major */>;

using TensorRefA = typename MmaWarp::IteratorA::TensorRef;
using TensorRefB = typename MmaWarp::IteratorB::TensorRef;
using FragmentA = typename MmaWarp::FragmentA;
using FragmentB = typename MmaWarp::FragmentB;
using FragmentC = typename MmaWarp::FragmentC;
using IteratorA = typename MmaWarp::IteratorA;
using IteratorB = typename MmaWarp::IteratorB;

static_assert(Shape::kK % InstructionShape::kK == 0);
static int constexpr kKgroups = Shape::kK / InstructionShape::kK;

static CUTLASS_DEVICE void body(const A_type_raw *pA, const B_type_raw *pB,
FragmentC &accum, int lda, int ldb,
double alpha, double beta,
const int warp_idx_m, const int warp_idx_n,
const int lane_id) {
MmaWarp mma_op;
FragmentA frag_A;
FragmentB frag_B;
const TensorRefA ref_A((A_type *)pA, lda);
const TensorRefB ref_B((B_type *)pB, ldb);
IteratorA iter_A(ref_A, lane_id);
IteratorB iter_B(ref_B, lane_id);
iter_A.add_tile_offset({warp_idx_m, 0});
iter_B.add_tile_offset({0, warp_idx_n});

// TODO: Check all cases of alpha and beta
// TODO: Static checking of alpha and beta
if (beta == 0) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentC::kElements; i++) {
accum[i] = 0;
}
} else {
assert(beta == 1);
}

CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_A.load(frag_A);
iter_B.load(frag_B);
++iter_A;
++iter_B;
mma_op(accum, frag_A, frag_B, accum);
using Instruction = DispatchInstruction<A_type, B_type, C_type>;

using OperandATraits =
OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
using OperandBTraits =
OperandTraits<sizeof_bits<B_type>::value, N, K, trans_B>;
using SmemLayoutA = typename OperandATraits::Layout;
using SmemLayoutB = typename OperandBTraits::Layout;
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

using TileMma =
TiledMMA<typename Instruction::MMA,
Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>>;

static CUTE_DEVICE void body(const A_type *pA, const B_type *pB, C_type *pC,
int lda, int ldb, double alpha, double beta,
int warp_id_m, int warp_id_n, int lane_id) {
int tid = (warp_id_n * num_warp_m + warp_id_m) * 32 + lane_id;
// Wrap the raw shared-memory pointers in the static layouts chosen by
// OperandTraits; pA and pB must already be laid out accordingly.
Tensor sA = make_tensor(make_smem_ptr((A_type *)(pA)), SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr((B_type *)(pB)), SmemLayoutB{});
TileMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma);
auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma);
auto thr_copy_A = tiled_copy_A.get_thread_slice(tid);
auto thr_copy_B = tiled_copy_B.get_thread_slice(tid);

Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCsA = thr_copy_A.partition_S(sA);
Tensor tCsB = thr_copy_B.partition_S(sB);

Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA);
Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB);

Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));

int num_tile_k = size<2>(tCrA);
CUTE_UNROLL
for (int k = 0; k < num_tile_k; ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc);
}
}
};
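
Usage sketch (an editorial assumption, not from this commit): instantiating the rewritten class for a 64x64x32 tile computed by a 2x2 grid of warps, in double precision — the one type combination whose new DispatchInstruction is visible above:

using MMA = GemmTensorOp<64, 64, 32, /*num_warp_m=*/2, /*num_warp_n=*/2,
                         /*trans_A=*/false, /*trans_B=*/false,
                         double, double, double>;
// Inside a kernel, with sA/sB in shared memory laid out per OperandTraits
// and acc a per-thread accumulator fragment:
//   MMA::body(sA, sB, acc, /*lda=*/32, /*ldb=*/32, /*alpha=*/1.0,
//             /*beta=*/0.0, warp_id_m, warp_id_n, lane_id);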
@@ -138,12 +92,9 @@ CUTLASS_DEVICE void matmul_thread(const A_type *pA, const B_type *pB,
int strideb, int stridec, double alpha,
double beta, int warp_id_batch, int warp_id_m,
int warp_id_n, int lane_id) {
using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n,
trans_A, trans_B, A_type, B_type, C_type>;
using FragmentC = typename MMA::FragmentC;
using MMA = GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A, trans_B,
A_type, B_type, C_type>;
MMA::body(pA + warp_id_batch * stridea, pB + warp_id_batch * strideb,
*(FragmentC *)(accum /* no thread offset */), lda, ldb, alpha,
beta, warp_id_m, warp_id_n, lane_id);
(accum /* no thread offset */), lda, ldb, alpha, beta, warp_id_m,
warp_id_n, lane_id);
}

#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H
51 changes: 33 additions & 18 deletions src/schedule/lower_cutlass_micro_block.cc
@@ -5,6 +5,7 @@
#include <pass/shrink_for.h>
#include <schedule/lower_cutlass_micro_block.h>
#include <schedule/var_merge.h>
#include <schedule/var_reorder.h>
#include <schedule/var_split.h>
#include <schedule/var_unsqueeze.h>

@@ -162,11 +163,11 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
auto batchInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_);
auto mInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 7], prop_->warpIdM_);
makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_);
auto nInWarpPartition =
makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdN_);
makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_);
auto mInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 5],
makeEQ(op->indices_[nDimsCAll - 3],
makeFloorDiv(prop_->laneId_, makeIntConst(4)));
auto nInThreadPartition =
makeEQ(op->indices_[nDimsCAll - 2],
@@ -222,10 +223,12 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
ASSERT(nDimsCAll >=
9); // See comments in `lowerCutlassMicroBlock` below
c->indices_[nDimsCAll - 9] = warpIdBatch;
c->indices_[nDimsCAll - 7] = warpIdM;
c->indices_[nDimsCAll - 5] = makeFloorDiv(laneId, makeIntConst(4));
c->indices_[nDimsCAll - 4] = warpIdN;
c->indices_[nDimsCAll - 2] = makeMod(laneId, makeIntConst(4));
c->indices_[nDimsCAll - 4] = warpIdM; // m warps
c->indices_[nDimsCAll - 3] =
makeFloorDiv(laneId, makeIntConst(4)); // m threads
c->indices_[nDimsCAll - 5] = warpIdN; // n warps
c->indices_[nDimsCAll - 2] =
makeMod(laneId, makeIntConst(4)); // n threads
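// Worked example (editorial assumption): for laneId == 13 the derived
// indices are 13 / 4 == 3 (m threads) and 13 % 4 == 1 (n threads), i.e. the
// 32 lanes of a warp tile the accumulator fragment as an 8x4 (m x n) grid.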

op->backend_ = MatMulBackend::CutlassMicroThread;
op->cutlassMicroKernelProperty_ = prop_;
@@ -278,13 +281,13 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId,
// ...: other leading dims,
// -9: batch warps,
// -8: batch serial,
// -7: m warps,
// -6: m 8-tiles,
// -5: m threads,
// -4: n warps,
// -3: n 8-tiles,
// -7: n 16-tiles,
// -6: m 16-tiles,
// -5: n warps,
// -4: m warps,
// -3: m threads,
// -2: n threads,
// -1: n 2-tiles
// -1: n 2-tiles,
// ]
//
// See
@@ -333,19 +336,31 @@
} else if (nDimsCBatch == 0) {
ast = varUnsqueeze(ast, defIdC, nDimsCOthers);
}

// clang-format off
ast = varSplit(
ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch);
ast = varSplit(
ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, 16, -1);
ast = varSplit(
ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, -1, nWarpM);
ast = varSplit(
ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, 8, -1);
ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, -1, nWarpM);
ast = varSplit(
ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, -1, nWarpN);
ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, 16, -1);
ast = varSplit(
ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, 8, -1);
ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, -1, nWarpN);
ast = varSplit(
ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, 2, -1);
std::vector<int> vec;
for (int i = 0; i <= nDimsCOthers + 1; i++)
    vec.push_back(i);
vec.push_back(nDimsCOthers+5);
vec.push_back(nDimsCOthers+2);
vec.push_back(nDimsCOthers+6);
vec.push_back(nDimsCOthers+3);
vec.push_back(nDimsCOthers+4);
vec.push_back(nDimsCOthers+7);
vec.push_back(nDimsCOthers+8);
ast = varReorderImpl(ast, defIdC, vec, true);
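// Worked example (editorial): with nDimsCOthers == 0 the dims entering the
// reorder are [batch warps, batch serial, m 16-tiles, m warps, m threads,
// n 16-tiles, n warps, n threads, n 2-tiles], so vec == {0, 1, 5, 2, 6, 3,
// 4, 7, 8} interleaves the m and n groups into the order documented in the
// comment block above.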
// clang-format on

// Lower to CutlassMicroThread
13 changes: 9 additions & 4 deletions src/schedule/var_reorder.cc
@@ -62,23 +62,28 @@ Expr VarReorder::visit(const Load &_op) {
}

Stmt VarReorder::visit(const MatMul &op) {
if (!var_.empty() && (allReads(op->equivalent_).count(var_) ||
if (!var_.empty() && !forceReorderInMatMul_ &&
    (allReads(op->equivalent_).count(var_) ||
allWrites(op->equivalent_).count(var_))) {
throw InvalidSchedule("Please call var_reorder before as_matmul");
}
return BaseClass::visit(op);
}

Stmt varReorder(const Stmt &_ast, const ID &def,
const std::vector<int> &order) {
VarReorder mutator(def, order);
Stmt varReorderImpl(const Stmt &_ast, const ID &def,
const std::vector<int> &order, bool forceReorderInMatMul) {
VarReorder mutator(def, order, forceReorderInMatMul);
auto ast = mutator(_ast);
if (!mutator.found()) {
throw InvalidSchedule(FT_MSG << def << " not found");
}
return ast;
}

Stmt varReorder(const Stmt &ast, const ID &def,
const std::vector<int> &order) {
return varReorderImpl(ast, def, order);
}
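
The split is deliberate: the public entry point keeps the old safety check, while the lowering pass opts in to reordering a var that already appears in a MatMul. A minimal sketch of the two call styles (editorial; `ast` and `def` assumed in scope):

// Normal scheduling path: rejects reordering a var already consumed by
// as_matmul (forceReorderInMatMul defaults to false).
ast = varReorder(ast, def, {0, 2, 1});
// Internal lowering path (see lower_cutlass_micro_block.cc above): the pass
// has already rearranged the MatMul to match, so it forces the reorder.
ast = varReorderImpl(ast, def, {0, 2, 1}, /*forceReorderInMatMul=*/true);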

void Schedule::varReorder(const ID &def, const std::vector<int> &order) {
beginTransaction();
auto log = appendLog(
