diff --git a/src/schedule/lower_cutlass_micro_block.cc b/src/schedule/lower_cutlass_micro_block.cc index 25bd530a6..af1a608ef 100644 --- a/src/schedule/lower_cutlass_micro_block.cc +++ b/src/schedule/lower_cutlass_micro_block.cc @@ -429,11 +429,11 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId, ast = varSplit(ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch); ast = varSplit(ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, - 16, -1); + 8 * nWarpM, -1); ast = varSplit(ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, -1, nWarpM); ast = varSplit(ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, - 16, -1); + 8 * nWarpN, -1); ast = varSplit(ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, -1, nWarpN); ast = varSplit(ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, @@ -452,13 +452,13 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId, ast = varSplit(ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch); ast = varSplit(ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, - 32, -1); + 16 * nWarpM, -1); ast = varSplit(ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, -1, nWarpM); ast = varSplit(ast, defIdC, nDimsCOthers + 4, VarSplitMode::FixedSize, -1, 2); ast = varSplit(ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, - 16, -1); + 8 * nWarpN, -1); ast = varSplit(ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, -1, nWarpN); ast = varSplit(ast, defIdC, nDimsCOthers + 8, VarSplitMode::FixedSize,