Skip to content

Commit

Permalink
Add CUTLASS as an optional backend of schedule/as_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Dec 17, 2023
1 parent d53b8d7 commit cf8ef27
Show file tree
Hide file tree
Showing 20 changed files with 446 additions and 134 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@
[submodule "3rd-party/range-v3"]
path = 3rd-party/range-v3
url = ../../ericniebler/range-v3.git
[submodule "3rd-party/cutlass"]
path = 3rd-party/cutlass
url = ../../NVIDIA/cutlass.git
1 change: 1 addition & 0 deletions 3rd-party/cutlass
Submodule cutlass added at a75b4a
16 changes: 14 additions & 2 deletions ffi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,20 @@ void init_ffi_schedule(py::module_ &m) {
.def("vectorize", &Schedule::vectorize, "loop"_a)
.def("separate_tail", &Schedule::separateTail,
"noDuplicateVarDefs"_a = false)
.def("as_matmul", &Schedule::asMatMul, "loop"_a,
"mode"_a = AsMatMulMode::KeepMemLayout)
.def("as_matmul",
static_cast<void (Schedule::*)(
const ID &, AsMatMulMode, const Ref<Target> &, MatMulBackend)>(
&Schedule::asMatMul),
"loop"_a, "mode"_a, "target"_a, "backend"_a)
.def("as_matmul",
static_cast<void (Schedule::*)(const ID &, AsMatMulMode,
const Ref<Target> &)>(
&Schedule::asMatMul),
"loop"_a, "mode"_a, "target"_a)
.def("as_matmul",
static_cast<void (Schedule::*)(const ID &, AsMatMulMode)>(
&Schedule::asMatMul),
"loop"_a, "mode"_a = AsMatMulMode::KeepMemLayout)
.def("pluto_fuse", &Schedule::plutoFuse, "loop0"_a, "loop1"_a,
"nest_level_0"_a = 0, "nest_level_1"_a = 0,
"fusable_overlap_threshold"_a = 1,
Expand Down
13 changes: 13 additions & 0 deletions ffi/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,19 @@ void init_ffi_ast_stmt(py::module_ &m) {
.def_readonly("tape_name", &MarkVersionNode::tapeName_)
.def_readonly("var", &MarkVersionNode::var_);

py::class_<MatMulBackend>(m, "MatMulBackend")
.def(py::init<MatMulBackend>())
.def(py::init(&parseMatMulBackend))
.def("__str__",
static_cast<std::string (*)(const MatMulBackend &)>(&toString))
.def("__hash__", [](MatMulBackend backend) { return (size_t)backend; })
.def("__eq__",
[](MatMulBackend lhs, MatMulBackend rhs) { return lhs == rhs; })
.def("__eq__", [](MatMulBackend lhs, const std::string &rhs) {
return lhs == parseMatMulBackend(rhs);
});
// no py::implicitly_convertible from str, because it fails silently

// makers
m.def("makeAny", []() { return makeAny(); });
m.def(
Expand Down
8 changes: 5 additions & 3 deletions include/codegen/code_gen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,21 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
typedef CodeGenCUDAStream Stream;

private:
Ref<GPUTarget> target_;
std::string kernelPrefix_;
int nKernel_ = 0;
Expr sharedStackTop_ = makeIntConst(0);
Expr globalStackTop_ = makeIntConst(0);
Expr globalSize_ = makeIntConst(0);
std::unordered_set<Stmt> streamScopes_;
bool inCublas_ = false;
bool inMatmul_ = false;

public:
CodeGenCUDA(const std::vector<FuncParam> &params,
const std::vector<FuncRet> &returns,
const std::string &kernelPrefix)
: CodeGenC(params, returns), kernelPrefix_(kernelPrefix) {}
const Ref<GPUTarget> &target, const std::string &kernelPrefix)
: CodeGenC(params, returns), target_(target),
kernelPrefix_(kernelPrefix) {}

using CodeGenC<CodeGenCUDAStream>::genMdPtrType;
using CodeGenC<CodeGenCUDAStream>::genMdPtrDef;
Expand Down
2 changes: 1 addition & 1 deletion include/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ class Mutator {

virtual Stmt visit(const MatMul &op) {
return makeMatMul(
(*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
op->backend_, (*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
(*this)(op->alpha_), (*this)(op->beta_), (*this)(op->m_),
(*this)(op->k_), (*this)(op->n_), (*this)(op->lda_),
(*this)(op->ldb_), (*this)(op->ldc_), (*this)(op->stridea_),
Expand Down
7 changes: 7 additions & 0 deletions include/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,9 +725,16 @@ class Schedule {
* may affect performance of other use of these variable. `TryTranspose` =>
* try `cache` and then `var_reorder` on some variables, but will incur
* extra overhead.
* @param target : Hardware target. If omitted, use the default target in
* Config, or the target set by `with` scopes.
* @param backend : Backend library. Defaults to `Mkl` for CPU targets,
* `Cublas` for GPU targets.
* @throw InvalidSchedule if the loop cannot be transformed to be a matrix
* multiplication
*/
void asMatMul(const ID &loop, AsMatMulMode mode, const Ref<Target> &target,
MatMulBackend backend);
void asMatMul(const ID &loop, AsMatMulMode mode, const Ref<Target> &target);
void asMatMul(const ID &loop,
AsMatMulMode mode = AsMatMulMode::KeepMemLayout);

Expand Down
6 changes: 4 additions & 2 deletions include/schedule/as_matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class AsMatMul : public SymbolTable<Mutator> {
typedef SymbolTable<Mutator> BaseClass;

ID loop_;
MatMulBackend backend_;

int nestCnt_ = 0;
std::unordered_map<std::string, int> iterMap_; // iter var -> nest cnt
Expand All @@ -74,7 +75,8 @@ class AsMatMul : public SymbolTable<Mutator> {
bool done_ = false;

public:
AsMatMul(const ID &loop) : loop_(loop) {}
AsMatMul(const ID &loop, MatMulBackend backend)
: loop_(loop), backend_(backend) {}

bool done() const { return done_; }

Expand Down Expand Up @@ -199,7 +201,7 @@ class AsMatMul : public SymbolTable<Mutator> {
Stmt visit(const VarDef &op) override;
};

Stmt asMatMul(const Stmt &ast, const ID &loop);
Stmt asMatMul(const Stmt &ast, const ID &loop, MatMulBackend backend);

} // namespace freetensor

Expand Down
47 changes: 44 additions & 3 deletions include/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,11 +456,51 @@ Stmt makeEval(T &&expr, const Metadata &metadata = nullptr, const ID &id = {},
return e;
}

/**
 * Backend library of MatMul nodes
 */
enum class MatMulBackend : size_t {
    Mkl = 0,
    Cublas,
    Cutlass,
    // ----------------------------
    NumBackends
};

// Lower-case display names, indexed by the numeric value of `MatMulBackend`
constexpr std::array matMulBackendNames = {
    "mkl",
    "cublas",
    "cutlass",
};
// Keep the name table in sync with the enum above. (Previously this asserted
// on `baseDataTypeNames` / `BaseDataType` — a copy-paste slip that left
// `matMulBackendNames` completely unchecked.)
static_assert(matMulBackendNames.size() == (size_t)MatMulBackend::NumBackends);

inline std::ostream &operator<<(std::ostream &os, MatMulBackend backend) {
    // `.at` bounds-checks in case an out-of-range value was cast to the enum
    return os << matMulBackendNames.at((size_t)backend);
}

// Parse a backend name (case-insensitive) into a `MatMulBackend` value.
// Reports an error listing all known candidates if the name is unrecognized.
inline MatMulBackend parseMatMulBackend(const std::string &_str) {
    auto &&str = tolower(_str);
    for (size_t i = 0; i < matMulBackendNames.size(); i++) {
        if (str == matMulBackendNames[i]) {
            return (MatMulBackend)i;
        }
    }
    std::string msg = "Unrecognized MatMul backend \"" + _str +
                      "\". Candidates are (case-insensitive): ";
    bool first = true;
    for (auto &&name : matMulBackendNames) {
        if (!first) {
            msg += ", ";
        }
        msg += name;
        first = false;
    }
    ERROR(msg);
}

/**
* External call to a batched GEMM
*/
class MatMulNode : public StmtNode {
public:
MatMulBackend backend_;

// c_ = alpha_ * a_ * b_ + beta_ * c_
// a_ is an m_ * k_ matrix
// b_ is a k_ * n_ matrix
Expand Down Expand Up @@ -489,9 +529,9 @@ class MatMulNode : public StmtNode {
};
typedef Ref<MatMulNode> MatMul;
inline Stmt
makeMatMul(const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
const Expr &beta, const Expr &m, const Expr &k, const Expr &n,
const Expr &lda, const Expr &ldb, const Expr &ldc,
makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k,
const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc,
const Expr &stridea, const Expr &strideb, const Expr &stridec,
const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor,
bool cIsRowMajor, const Stmt &equivalent,
Expand All @@ -500,6 +540,7 @@ makeMatMul(const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
MatMul s = MatMul::make();
s->metadata() = metadata;
s->setId(id);
s->backend_ = backend;
s->a_ = a;
s->b_ = b;
s->c_ = c;
Expand Down
27 changes: 25 additions & 2 deletions python/freetensor/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .jit import JITTemplate
from .meta import MemType
from .utils import as_decorator
from .driver import TargetType
from . import config


class IDMap:
Expand Down Expand Up @@ -779,7 +781,11 @@ def separate_tail(self, noDuplicateVarDefs=False):
"""
super().separate_tail(noDuplicateVarDefs)

def as_matmul(self, loop, mode: AsMatMulMode = AsMatMulMode.KeepMemLayout):
def as_matmul(self,
loop,
mode: AsMatMulMode = AsMatMulMode.KeepMemLayout,
target=None,
backend: Union[str, ffi.MatMulBackend] = None):
"""
Transform nested loops to be a external call to a matrix multiplication
Expand All @@ -795,13 +801,30 @@ def as_matmul(self, loop, mode: AsMatMulMode = AsMatMulMode.KeepMemLayout):
=> try `var_reorder` on some variables, but may affect performance of
other use of these variable. `TryTranspose` => try `cache` and then
`var_reorder` on some variables, but will incur extra overhead.
target : Target
Hardware target. If omitted, use the default target in config, or the
target set by `with` scopes.
backend : str, ffi.MatMulBackend
Backend library. Defaults to "mkl" for CPU targets, "cublas" for GPU
targets.
Raises
------
InvalidSchedule
if the loop cannot be transformed to be a matrix multiplication
"""
super().as_matmul(self._lookup(loop), mode)
if target is None:
target = config.default_target()
if backend is None:
if target.type() == TargetType.CPU:
backend = "mkl"
elif target.type() == TargetType.GPU:
backend = "cublas"
else:
raise ffi.InvalidSchedule(
"No default MatMul backend for target " + target)
super().as_matmul(self._lookup(loop), mode, target,
ffi.MatMulBackend(backend))

def pluto_fuse(self,
loop0,
Expand Down
2 changes: 2 additions & 0 deletions runtime/gpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ inline const char *cublasGetErrorString(cublasStatus_t error) {
return "<unknown>";
}

// checkCutlassError is defined in gpu_runtime.h

class GPUContext : public Context {
bool initialized_ = false;
cublasHandle_t cublas_;
Expand Down
13 changes: 13 additions & 0 deletions runtime/gpu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include <stdexcept>
#include <type_traits>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

#include "gpu_context.h"

#include "mdspan.h"
Expand All @@ -20,6 +23,16 @@

#define checkCudaError(...) runtimeCheckCudaError(__VA_ARGS__)

// Check the status returned by a CUTLASS call; on failure, print the
// location and status string to stderr and throw std::runtime_error.
// Wrapped in do { ... } while (0) so the macro expands to exactly one
// statement: the original bare-brace form (`{ ... }` plus the call site's
// `;`) is two statements and breaks inside an unbraced if/else.
#define checkCutlassError(call)                                                \
    do {                                                                       \
        auto err = (call);                                                     \
        if (cutlass::Status::kSuccess != err) {                                \
            fprintf(stderr, "Cutlass error in file '%s' in line %i : %s.\n",   \
                    __FILE__, __LINE__, cutlassGetStatusString(err));          \
            throw std::runtime_error("cutlass error");                         \
        }                                                                      \
    } while (0)

inline void *cudaNew(size_t size, cudaStream_t stream) {
void *ptr = nullptr;
if (size > 0) {
Expand Down
Loading

0 comments on commit cf8ef27

Please sign in to comment.