Skip to content

Commit

Permalink
float16 support for cute micro kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
cerf-volantWang committed Mar 31, 2024
1 parent a897c8e commit d9d8b8b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 19 deletions.
25 changes: 7 additions & 18 deletions runtime/micro_kernel/matmul/cutlass/gemm_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,9 @@ struct OperandTraits {
};

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
bool trans_B, typename A_type, typename B_type, typename C_type>
class GemmTensorOp {
public:
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;

using Instruction = DispatchInstruction<A_type, B_type, C_type>;

using OperandATraits =
Expand All @@ -56,15 +47,13 @@ class GemmTensorOp {
using SmemCopyA = Copy_Atom<typename OperandATraits::Copy, A_type>;
using SmemCopyB = Copy_Atom<typename OperandBTraits::Copy, B_type>;

using TileMma = TiledMMA<typename Instruction::MMA,
Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>
/*,typename Instruction::MMA_Group*/>;

static CUTE_DEVICE void body(const A_type_raw *pA, const B_type_raw *pB,
C_type_raw *pC, int lda, int ldb, double alpha,
double beta, int warp_id_m, int warp_id_n,
int lane_id) {
using TileMma =
TiledMMA<typename Instruction::MMA,
Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>>;

static CUTE_DEVICE void body(const A_type *pA, const B_type *pB, C_type *pC,
int lda, int ldb, double alpha, double beta,
int warp_id_m, int warp_id_n, int lane_id) {
int tid = (warp_id_n * num_warp_m + warp_id_m) * 32 + lane_id;
// change the layout!!!
Tensor sA = make_tensor(make_smem_ptr((A_type *)(pA)), SmemLayoutA{});
Expand Down
1 change: 0 additions & 1 deletion src/schedule/lower_cutlass_micro_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ class LowerCutlassMicroBlock : public SymbolTable<Mutator> {
if (inMicroKernel_) {
throw InvalidSchedule("Micro kernels cannot nest each other");
}

// Here we use `threadIdx.x` for threads in a warp, and
// `threadIdx.y` for warps, because putting everything into a single
// `threadIdx.x` will make the expressions too complicated to solve.
Expand Down
1 change: 1 addition & 0 deletions test/70.program/test_program_with_micro_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,5 @@ def matmul(a: ft.Var[(M, K), "float16"], b: ft.Var[(K, N), "float16"]):
b_arr = ft.array(b_torch)
y_arr = exe(a_arr, b_arr)
y_torch = y_arr.torch()

assert torch.all(torch.isclose(y_torch, y_std, rtol=2e-2))

0 comments on commit d9d8b8b

Please sign in to comment.