diff --git a/.gitmodules b/.gitmodules
index 3c98b9f79..909273d34 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -21,3 +21,6 @@
 [submodule "3rd-party/range-v3"]
 	path = 3rd-party/range-v3
 	url = ../../ericniebler/range-v3.git
+[submodule "3rd-party/cutlass"]
+	path = 3rd-party/cutlass
+	url = ../../NVIDIA/cutlass.git
diff --git a/3rd-party/cutlass b/3rd-party/cutlass
new file mode 160000
index 000000000..a75b4ac48
--- /dev/null
+++ b/3rd-party/cutlass
@@ -0,0 +1 @@
+Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
diff --git a/ffi/schedule.cc b/ffi/schedule.cc
index 4feff9be7..87bf5d684 100644
--- a/ffi/schedule.cc
+++ b/ffi/schedule.cc
@@ -135,8 +135,20 @@ void init_ffi_schedule(py::module_ &m) {
         .def("vectorize", &Schedule::vectorize, "loop"_a)
         .def("separate_tail", &Schedule::separateTail,
              "noDuplicateVarDefs"_a = false)
-        .def("as_matmul", &Schedule::asMatMul, "loop"_a,
-             "mode"_a = AsMatMulMode::KeepMemLayout)
+        .def("as_matmul",
+             static_cast<void (Schedule::*)(const ID &, AsMatMulMode,
+                                            const Ref<Target> &,
+                                            MatMulBackend)>(
+                 &Schedule::asMatMul),
+             "loop"_a, "mode"_a, "target"_a, "backend"_a)
+        .def("as_matmul",
+             static_cast<void (Schedule::*)(const ID &, AsMatMulMode,
+                                            const Ref<Target> &)>(
+                 &Schedule::asMatMul),
+             "loop"_a, "mode"_a, "target"_a)
+        .def("as_matmul",
+             static_cast<void (Schedule::*)(const ID &, AsMatMulMode)>(
+                 &Schedule::asMatMul),
+             "loop"_a, "mode"_a = AsMatMulMode::KeepMemLayout)
         .def("pluto_fuse", &Schedule::plutoFuse, "loop0"_a, "loop1"_a,
              "nest_level_0"_a = 0, "nest_level_1"_a = 0,
              "fusable_overlap_threshold"_a = 1,
diff --git a/ffi/stmt.cc b/ffi/stmt.cc
index 56652b6b4..5b4805b42 100644
--- a/ffi/stmt.cc
+++ b/ffi/stmt.cc
@@ -114,6 +114,19 @@ void init_ffi_ast_stmt(py::module_ &m) {
         .def_readonly("tape_name", &MarkVersionNode::tapeName_)
         .def_readonly("var", &MarkVersionNode::var_);
 
+    py::class_<MatMulBackend>(m, "MatMulBackend")
+        .def(py::init<MatMulBackend>())
+        .def(py::init(&parseMatMulBackend))
+        .def("__str__",
+             static_cast<std::string (*)(MatMulBackend)>(&toString))
+        .def("__hash__", [](MatMulBackend backend) { return (size_t)backend; })
+        .def("__eq__",
+             [](MatMulBackend lhs, MatMulBackend rhs) { return lhs == rhs; })
+        .def("__eq__", [](MatMulBackend lhs, const std::string &rhs) {
+            return lhs == parseMatMulBackend(rhs);
+        });
+    // No py::implicitly_convertible from str, because it fails silently
+
     // makers
     m.def("makeAny", []() { return makeAny(); });
     m.def(
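Note: since implicit conversion from `str` is deliberately disabled, Python callers wrap strings explicitly. A sketch of the behavior this binding aims at (the exact `ffi` import path is an assumption; it refers to the same native module that `schedule.py` below uses as `ffi`):

    from freetensor import ffi  # assumed re-export of the native module

    backend = ffi.MatMulBackend("CUTLASS")  # parseMatMulBackend is case-insensitive
    assert str(backend) == "cutlass"        # __str__ goes through toString
    assert backend == "cutlass"             # __eq__ also accepts plain strings
    assert backend == ffi.MatMulBackend("cutlass")
    # __hash__ is the raw enum value, so backends can key a dict:
    defaults = {ffi.MatMulBackend("mkl"): "cpu", ffi.MatMulBackend("cublas"): "gpu"}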
diff --git a/include/codegen/code_gen_cuda.h b/include/codegen/code_gen_cuda.h
index 7a28d3565..27ea9f7e5 100644
--- a/include/codegen/code_gen_cuda.h
+++ b/include/codegen/code_gen_cuda.h
@@ -20,19 +20,21 @@ class CodeGenCUDA : public CodeGenC<CodeGenCUDAStream> {
     typedef CodeGenCUDAStream Stream;
 
   private:
+    Ref<GPUTarget> target_;
     std::string kernelPrefix_;
    int nKernel_ = 0;
     Expr sharedStackTop_ = makeIntConst(0);
     Expr globalStackTop_ = makeIntConst(0);
     Expr globalSize_ = makeIntConst(0);
     std::unordered_set<std::string> streamScopes_;
-    bool inCublas_ = false;
+    bool inMatmul_ = false;
 
   public:
     CodeGenCUDA(const std::vector<FuncParam> &params,
                 const std::vector<FuncRet> &returns,
-                const std::string &kernelPrefix)
-        : CodeGenC(params, returns), kernelPrefix_(kernelPrefix) {}
+                const Ref<GPUTarget> &target, const std::string &kernelPrefix)
+        : CodeGenC(params, returns), target_(target),
+          kernelPrefix_(kernelPrefix) {}
 
     using CodeGenC::genMdPtrType;
     using CodeGenC::genMdPtrDef;
diff --git a/include/mutator.h b/include/mutator.h
index e781e2942..9b24d6af4 100644
--- a/include/mutator.h
+++ b/include/mutator.h
@@ -339,7 +339,7 @@ class Mutator {
     virtual Stmt visit(const MatMul &op) {
         return makeMatMul(
-            (*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
+            op->backend_, (*this)(op->a_), (*this)(op->b_), (*this)(op->c_),
             (*this)(op->alpha_), (*this)(op->beta_), (*this)(op->m_),
             (*this)(op->k_), (*this)(op->n_), (*this)(op->lda_),
             (*this)(op->ldb_), (*this)(op->ldc_), (*this)(op->stridea_),
diff --git a/include/schedule.h b/include/schedule.h
index 4d61de8b1..b672bd9a2 100644
--- a/include/schedule.h
+++ b/include/schedule.h
@@ -725,9 +725,16 @@ class Schedule {
      * may affect performance of other uses of these variables. `TryTranspose`
      * => try `cache` and then `var_reorder` on some variables, but will incur
     * extra overhead.
+     * @param target : Hardware target. If omitted, use the default target in
+     * Config, or the target set by `with` scopes.
+     * @param backend : Backend library. Defaults to `Mkl` for CPU targets,
+     * `Cublas` for GPU targets.
      * @throw InvalidSchedule if the loop cannot be transformed to be a matrix
      * multiplication
      */
+    void asMatMul(const ID &loop, AsMatMulMode mode,
+                  const Ref<Target> &target, MatMulBackend backend);
+    void asMatMul(const ID &loop, AsMatMulMode mode, const Ref<Target> &target);
     void asMatMul(const ID &loop,
                   AsMatMulMode mode = AsMatMulMode::KeepMemLayout);
diff --git a/include/schedule/as_matmul.h b/include/schedule/as_matmul.h
index 575bb1056..5c4a49c2b 100644
--- a/include/schedule/as_matmul.h
+++ b/include/schedule/as_matmul.h
@@ -57,6 +57,7 @@ class AsMatMul : public SymbolTable<Mutator> {
     typedef SymbolTable<Mutator> BaseClass;
 
     ID loop_;
+    MatMulBackend backend_;
     int nestCnt_ = 0;
     std::unordered_map<std::string, int> iterMap_; // iter var -> nest cnt
@@ -74,7 +75,8 @@ class AsMatMul : public SymbolTable<Mutator> {
     bool done_ = false;
 
   public:
-    AsMatMul(const ID &loop) : loop_(loop) {}
+    AsMatMul(const ID &loop, MatMulBackend backend)
+        : loop_(loop), backend_(backend) {}
 
     bool done() const { return done_; }
 
@@ -199,7 +201,7 @@ class AsMatMul : public SymbolTable<Mutator> {
     Stmt visit(const VarDef &op) override;
 };
 
-Stmt asMatMul(const Stmt &ast, const ID &loop);
+Stmt asMatMul(const Stmt &ast, const ID &loop, MatMulBackend backend);
 
 } // namespace freetensor
diff --git a/include/stmt.h b/include/stmt.h
index 0322a1184..55722cda3 100644
--- a/include/stmt.h
+++ b/include/stmt.h
@@ -456,11 +456,51 @@ Stmt makeEval(T &&expr, const Metadata &metadata = nullptr, const ID &id = {},
     return e;
 }
 
+/**
+ * Backend library of MatMul nodes
+ */
+enum class MatMulBackend : size_t {
+    Mkl = 0,
+    Cublas,
+    Cutlass,
+    // ----------------------------
+    NumBackends
+};
+
+constexpr std::array matMulBackendNames = {
+    "mkl",
+    "cublas",
+    "cutlass",
+};
+static_assert(matMulBackendNames.size() == (size_t)MatMulBackend::NumBackends);
+
+inline std::ostream &operator<<(std::ostream &os, MatMulBackend backend) {
+    return os << matMulBackendNames.at((size_t)backend);
+}
+
+inline MatMulBackend parseMatMulBackend(const std::string &_str) {
+    auto &&str = tolower(_str);
+    for (auto &&[i, s] : views::enumerate(matMulBackendNames)) {
+        if (s == str) {
+            return (MatMulBackend)i;
+        }
+    }
+    std::string msg = "Unrecognized MatMul backend \"" + _str +
+                      "\". Candidates are (case-insensitive): ";
+    for (auto &&[i, s] : views::enumerate(matMulBackendNames)) {
+        msg += (i > 0 ? ", " : "");
+        msg += s;
+    }
+    ERROR(msg);
+}
+
 /**
  * External call to a batched GEMM
  */
 class MatMulNode : public StmtNode {
   public:
+    MatMulBackend backend_;
+
     // c_ = alpha_ * a_ * b_ + beta_ * c_
     // a_ is an m_ * k_ matrix
     // b_ is a k_ * n_ matrix
@@ -489,9 +529,9 @@ class MatMulNode : public StmtNode {
 };
 typedef Ref<MatMulNode> MatMul;
 inline Stmt
-makeMatMul(const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
-           const Expr &beta, const Expr &m, const Expr &k, const Expr &n,
-           const Expr &lda, const Expr &ldb, const Expr &ldc,
+makeMatMul(MatMulBackend backend, const Expr &a, const Expr &b, const Expr &c,
+           const Expr &alpha, const Expr &beta, const Expr &m, const Expr &k,
+           const Expr &n, const Expr &lda, const Expr &ldb, const Expr &ldc,
            const Expr &stridea, const Expr &strideb, const Expr &stridec,
            const Expr &batchSize, bool aIsRowMajor, bool bIsRowMajor,
           bool cIsRowMajor, const Stmt &equivalent,
@@ -500,6 +540,7 @@ makeMatMul(const Expr &a, const Expr &b, const Expr &c, const Expr &alpha,
     MatMul s = MatMul::make();
     s->metadata() = metadata;
     s->setId(id);
+    s->backend_ = backend;
     s->a_ = a;
     s->b_ = b;
     s->c_ = c;
diff --git a/python/freetensor/core/schedule.py b/python/freetensor/core/schedule.py
index dc264105a..3de272bce 100644
--- a/python/freetensor/core/schedule.py
+++ b/python/freetensor/core/schedule.py
@@ -16,6 +16,8 @@
 from .jit import JITTemplate
 from .meta import MemType
 from .utils import as_decorator
+from .driver import TargetType
+from . import config
 
 
 class IDMap:
@@ -779,7 +781,11 @@ def separate_tail(self, noDuplicateVarDefs=False):
         """
         super().separate_tail(noDuplicateVarDefs)
 
-    def as_matmul(self, loop, mode: AsMatMulMode = AsMatMulMode.KeepMemLayout):
+    def as_matmul(self,
+                  loop,
+                  mode: AsMatMulMode = AsMatMulMode.KeepMemLayout,
+                  target=None,
+                  backend: Union[str, ffi.MatMulBackend] = None):
         """
         Transform nested loops to be an external call to a matrix multiplication
 
@@ -795,13 +801,30 @@ def as_matmul(self, loop, mode: AsMatMulMode = AsMatMulMode.KeepMemLayout):
         => try `var_reorder` on some variables, but may affect performance of
         other uses of these variables. `TryTranspose` => try `cache` and then
         `var_reorder` on some variables, but will incur extra overhead.
+        target : Target
+            Hardware target. If omitted, use the default target in config, or
+            the target set by `with` scopes.
+        backend : str, ffi.MatMulBackend
+            Backend library. Defaults to "mkl" for CPU targets, "cublas" for
+            GPU targets.
 
         Raises
         ------
         InvalidSchedule
             if the loop cannot be transformed to be a matrix multiplication
         """
-        super().as_matmul(self._lookup(loop), mode)
+        if target is None:
+            target = config.default_target()
+        if backend is None:
+            if target.type() == TargetType.CPU:
+                backend = "mkl"
+            elif target.type() == TargetType.GPU:
+                backend = "cublas"
+            else:
+                raise ffi.InvalidSchedule(
+                    "No default MatMul backend for target " + str(target))
+        super().as_matmul(self._lookup(loop), mode, target,
+                          ffi.MatMulBackend(backend))
 
     def pluto_fuse(self,
                    loop0,
diff --git a/runtime/gpu_context.h b/runtime/gpu_context.h
index 00a2ee26c..ba175ace4 100644
--- a/runtime/gpu_context.h
+++ b/runtime/gpu_context.h
@@ -52,6 +52,8 @@ inline const char *cublasGetErrorString(cublasStatus_t error) {
     return "";
 }
 
+// checkCutlassError is defined in gpu_runtime.h
+
 class GPUContext : public Context {
     bool initialized_ = false;
     cublasHandle_t cublas_;
diff --git a/runtime/gpu_runtime.h b/runtime/gpu_runtime.h
index 8e45f086c..d311130d5 100644
--- a/runtime/gpu_runtime.h
+++ b/runtime/gpu_runtime.h
@@ -9,6 +9,9 @@
 #include
 #include
 
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+
 #include "gpu_context.h"
 #include "mdspan.h"
 
@@ -20,6 +23,16 @@
 
 #define checkCudaError(...) runtimeCheckCudaError(__VA_ARGS__)
 
+#define checkCutlassError(call)                                               \
+    {                                                                         \
+        auto err = (call);                                                    \
+        if (cutlass::Status::kSuccess != err) {                               \
+            fprintf(stderr, "Cutlass error in file '%s' in line %i : %s.\n",  \
+                    __FILE__, __LINE__, cutlassGetStatusString(err));         \
+            throw std::runtime_error("cutlass error");                        \
+        }                                                                     \
+    }
+
 inline void *cudaNew(size_t size, cudaStream_t stream) {
     void *ptr = nullptr;
     if (size > 0) {
diff --git a/src/codegen/code_gen_cpu.cc b/src/codegen/code_gen_cpu.cc
index 13ecde671..db81e15b5 100644
--- a/src/codegen/code_gen_cpu.cc
+++ b/src/codegen/code_gen_cpu.cc
@@ -324,21 +324,6 @@ void CodeGenCPU::visit(const For &op) {
 }
 
 void CodeGenCPU::visit(const MatMul &op) {
-#ifdef FT_WITH_MKL
-    makeIndent();
-    if (inParallel_) {
-        os() << "mkl_set_num_threads_local(1);" << std::endl;
-        // TODO: set it to max(1, cpu_count / outer_threads_count)
-    } else {
-        os() << "mkl_set_num_threads_local(0); // 0 == reset" << std::endl;
-    }
-
-    auto d = op->c_->dtype();
-    if (op->a_->dtype() != d || op->b_->dtype() != d) {
-        throw InvalidProgram(
-            "MKL requires all matrices have the same data type");
-    }
-
     bool transA = !op->aIsRowMajor_, transB = !op->bIsRowMajor_;
     Expr a = op->a_, b = op->b_, c = op->c_;
     Expr m = op->m_, k = op->k_, n = op->n_;
@@ -354,44 +339,68 @@ void CodeGenCPU::visit(const MatMul &op) {
         std::swap(n, m);
     }
 
-    makeIndent();
-    os() << "cblas_" << genMKLTypeMark(d)
-         << "gemm_batch_strided(CblasRowMajor, "
-         << (transA ? "CblasTrans" : "CblasNoTrans") << ", "
-         << (transB ? "CblasTrans" : "CblasNoTrans") << ", ";
-    (*this)(m);
-    os() << ", ";
-    (*this)(n);
-    os() << ", ";
-    (*this)(k);
-    os() << ", ";
-    (*this)(op->alpha_);
-    os() << ", &";
-    (*this)(a);
-    os() << ", ";
-    (*this)(lda);
-    os() << ", ";
-    (*this)(stridea);
-    os() << ", &";
-    (*this)(b);
-    os() << ", ";
-    (*this)(ldb);
-    os() << ", ";
-    (*this)(strideb);
-    os() << ", ";
-    (*this)(op->beta_);
-    os() << ", &";
-    (*this)(c);
-    os() << ", ";
-    (*this)(ldc);
-    os() << ", ";
-    (*this)(stridec);
-    os() << ", ";
-    (*this)(op->batchSize_);
-    os() << ");" << std::endl;
+    switch (op->backend_) {
+    case MatMulBackend::Mkl: {
+#ifdef FT_WITH_MKL
+        makeIndent();
+        if (inParallel_) {
+            os() << "mkl_set_num_threads_local(1);" << std::endl;
+            // TODO: set it to max(1, cpu_count / outer_threads_count)
+        } else {
+            os() << "mkl_set_num_threads_local(0); // 0 == reset" << std::endl;
+        }
+
+        auto d = op->c_->dtype();
+        if (op->a_->dtype() != d || op->b_->dtype() != d) {
+            throw InvalidProgram(
+                "MKL requires all matrices have the same data type");
+        }
+
+        makeIndent();
+        os() << "cblas_" << genMKLTypeMark(d)
+             << "gemm_batch_strided(CblasRowMajor, "
+             << (transA ? "CblasTrans" : "CblasNoTrans") << ", "
+             << (transB ? "CblasTrans" : "CblasNoTrans") << ", ";
+        (*this)(m);
+        os() << ", ";
+        (*this)(n);
+        os() << ", ";
+        (*this)(k);
+        os() << ", ";
+        (*this)(op->alpha_);
+        os() << ", &";
+        (*this)(a);
+        os() << ", ";
+        (*this)(lda);
+        os() << ", ";
+        (*this)(stridea);
+        os() << ", &";
+        (*this)(b);
+        os() << ", ";
+        (*this)(ldb);
+        os() << ", ";
+        (*this)(strideb);
+        os() << ", ";
+        (*this)(op->beta_);
+        os() << ", &";
+        (*this)(c);
+        os() << ", ";
+        (*this)(ldc);
+        os() << ", ";
+        (*this)(stridec);
+        os() << ", ";
+        (*this)(op->batchSize_);
+        os() << ");" << std::endl;
 #else
-    ERROR("Configuring with MKL is needed");
+        ERROR("Configuring with MKL is needed");
 #endif
+        break;
+    }
+    default:
+        throw InvalidProgram("MatMul backend " +
+                             freetensor::toString(op->backend_) +
+                             " is not supported for CPU");
+    }
 }
 
 NativeCode codeGenCPU(const Func &func, const Ref<Target> &target) {
diff --git a/src/codegen/code_gen_cuda.cc b/src/codegen/code_gen_cuda.cc
index c6f717f13..44d17b394 100644
--- a/src/codegen/code_gen_cuda.cc
+++ b/src/codegen/code_gen_cuda.cc
@@ -28,6 +28,16 @@ static std::string genCUBLASType(DataType dtype) {
     }
 }
 
+static bool canUseTensorCore(const Ref<GPUTarget> &target, DataType dtypeA,
+                             DataType dtypeB, DataType dtypeC) {
+    // TODO: fp16 is supported after sm70
+    if (target->computeCapability().first >= 8 &&
+        dtypeA == DataType::Float64 && dtypeB == DataType::Float64 &&
+        dtypeC == DataType::Float64) {
+        return true;
+    }
+    return false;
+}
+
 std::function CodeGenCUDA::genMdPtrType(const VarDef &def,
                                         bool isConst) {
     Ref<Buffer> buf = def->buffer_;
@@ -127,7 +137,7 @@ void CodeGenCUDA::genScalar(const VarDef &def,
 }
 
 bool CodeGenCUDA::inKernel() const {
-    return streamStack_.back().name_ != "default" || inCublas_;
+    return streamStack_.back().name_ != "default" || inMatmul_;
 }
 
 void CodeGenCUDA::exprOr1(const std::unordered_map &dict,
@@ -751,77 +761,151 @@ void CodeGenCUDA::visit(const VarDef &op) {
 }
 
 void CodeGenCUDA::visit(const MatMul &op) {
-    if (inKernel()) {
-        throw InvalidProgram("External call to a matrix multiplication from "
-                             "inside a CUDA kernel is not supported");
-    }
-
-    inCublas_ = true;
+    bool thisOpInKernel = inKernel();
+    inMatmul_ = true;
 
-    bool transA = !op->aIsRowMajor_, transB = !op->bIsRowMajor_;
+    bool transA = !op->aIsRowMajor_, transB = !op->bIsRowMajor_,
+         transC = !op->cIsRowMajor_;
     Expr a = op->a_, b = op->b_, c = op->c_;
     Expr m = op->m_, k = op->k_, n = op->n_;
     Expr lda = op->lda_, ldb = op->ldb_, ldc = op->ldc_;
     Expr stridea = op->stridea_, strideb = op->strideb_,
         stridec = op->stridec_;
-    if (op->cIsRowMajor_) {
-        transA = !transA;
-        transB = !transB;
-        std::swap(transA, transB);
-        std::swap(a, b);
-        std::swap(lda, ldb);
-        std::swap(stridea, strideb);
-        std::swap(n, m);
+
+    switch (op->backend_) {
+    case MatMulBackend::Cublas: {
+        if (thisOpInKernel) {
+            throw InvalidProgram("External call to a matrix multiplication "
+                                 "implemented by cuBLAS from inside a CUDA "
+                                 "kernel is not supported");
+        }
+
+        if (op->cIsRowMajor_) {
+            transA = !transA;
+            transB = !transB;
+            transC = false;
+            std::swap(transA, transB);
+            std::swap(a, b);
+            std::swap(lda, ldb);
+            std::swap(stridea, strideb);
+            std::swap(n, m);
+        }
+
+        makeIndent();
+        beginBlock();
+        makeIndent();
+        os() << gen(op->c_->dtype()) << " cublasAlpha = ";
+        (*this)(op->alpha_);
+        os() << ", cublasBeta = ";
+        (*this)(op->beta_);
+        os() << ";" << std::endl;
+        makeIndent();
+        os() << "cublasGemmStridedBatchedEx(ctx->cublas(), "
+             << (transA ? "CUBLAS_OP_N" : "CUBLAS_OP_T") << ", "
+             << (transB ? "CUBLAS_OP_N" : "CUBLAS_OP_T") << ", ";
+        (*this)(m);
+        os() << ", ";
+        (*this)(n);
+        os() << ", ";
+        (*this)(k);
+        os() << ", &cublasAlpha, &";
+        (*this)(a);
+        os() << ", " << genCUBLASType(op->a_->dtype()) << ", ";
+        (*this)(lda);
+        os() << ", ";
+        (*this)(stridea);
+        os() << ", &";
+        (*this)(b);
+        os() << ", " << genCUBLASType(op->b_->dtype()) << ", ";
+        (*this)(ldb);
+        os() << ", ";
+        (*this)(strideb);
+        os() << ", &cublasBeta, &";
+        (*this)(c);
+        os() << ", " << genCUBLASType(op->c_->dtype()) << ", ";
+        (*this)(ldc);
+        os() << ", ";
+        (*this)(stridec);
+        os() << ", ";
+        (*this)(op->batchSize_);
+        os() << ", " << genCUBLASType(op->c_->dtype())
+             << ", CUBLAS_GEMM_DEFAULT);" << std::endl;
+        endBlock();
+        break;
     }
-    makeIndent();
-    beginBlock();
-    makeIndent();
-    os() << gen(op->c_->dtype()) << " cublasAlpha = ";
-    (*this)(op->alpha_);
-    os() << ", cublasBeta = ";
-    (*this)(op->beta_);
-    os() << ";" << std::endl;
-    makeIndent();
-    os() << "cublasGemmStridedBatchedEx(ctx->cublas(), "
-         << (transA ? "CUBLAS_OP_N" : "CUBLAS_OP_T") << ", "
-         << (transB ? "CUBLAS_OP_N" : "CUBLAS_OP_T") << ", ";
-    (*this)(m);
-    os() << ", ";
-    (*this)(n);
-    os() << ", ";
-    (*this)(k);
-    os() << ", &cublasAlpha, &";
-    (*this)(a);
-    os() << ", " << genCUBLASType(op->a_->dtype()) << ", ";
-    (*this)(lda);
-    os() << ", ";
-    (*this)(stridea);
-    os() << ", &";
-    (*this)(b);
-    os() << ", " << genCUBLASType(op->b_->dtype()) << ", ";
-    (*this)(ldb);
-    os() << ", ";
-    (*this)(strideb);
-    os() << ", &cublasBeta, &";
-    (*this)(c);
-    os() << ", " << genCUBLASType(op->c_->dtype()) << ", ";
-    (*this)(ldc);
-    os() << ", ";
-    (*this)(stridec);
-    os() << ", ";
-    (*this)(op->batchSize_);
-    os() << ", " << genCUBLASType(op->c_->dtype()) << ", CUBLAS_GEMM_DEFAULT);"
-         << std::endl;
-    endBlock();
+    case MatMulBackend::Cutlass: {
+        if (thisOpInKernel) {
+            throw InvalidProgram("External call to a matrix multiplication "
+                                 "implemented by CUTLASS from inside a CUDA "
+                                 "kernel is not supported");
+        }
 
-    inCublas_ = false;
+        makeIndent();
+        os() << "cutlass::gemm::device::Gemm<" << gen(op->a_->dtype()) << ", "
+             << (transA ? "cutlass::layout::ColumnMajor"
+                        : "cutlass::layout::RowMajor")
+             << ", " << gen(op->b_->dtype()) << ", "
+             << (transB ? "cutlass::layout::ColumnMajor"
+                        : "cutlass::layout::RowMajor")
+             << ", " << gen(op->c_->dtype()) << ", "
+             << (transC ? "cutlass::layout::ColumnMajor"
+                        : "cutlass::layout::RowMajor")
+             << ", " << gen(op->c_->dtype()) // TODO: accumulator type
+             << ", "
+             << (canUseTensorCore(target_, op->a_->dtype(), op->b_->dtype(),
+                                  op->c_->dtype())
+                     ? "cutlass::arch::OpClassTensorOp"
+                     : "cutlass::arch::OpClassSimt")
+             << ", FT_CUTLASS_ARCH> gemm;" << std::endl;
+        makeIndent();
+        os() << "checkCutlassError(gemm({{";
+        (*this)(m);
+        os() << ", ";
+        (*this)(n);
+        os() << ", ";
+        (*this)(k);
+        os() << "}, {&";
+        (*this)(a);
+        os() << ", ";
+        (*this)(lda);
+        os() << "}, {&";
+        (*this)(b);
+        os() << ", ";
+        (*this)(ldb);
+        os() << "}, {&";
+        (*this)(c);
+        os() << ", ";
+        (*this)(ldc);
+        os() << "}, {&";
+        (*this)(c);
+        os() << ", ";
+        (*this)(ldc);
+        os() << "}, {";
+        (*this)(op->alpha_);
+        os() << ", ";
+        (*this)(op->beta_);
+        os() << "}}, nullptr, __stream));" << std::endl;
+        break;
+    }
+
+    default:
+        inMatmul_ = false;
+        throw InvalidProgram("MatMul backend " +
+                             freetensor::toString(op->backend_) +
+                             " is not supported for GPU");
+    }
+
+    inMatmul_ = false;
 }
 
-NativeCode codeGenCUDA(const Func &func, const Ref<Target> &target) {
+NativeCode codeGenCUDA(const Func &func, const Ref<Target> &_target) {
+    ASSERT(_target->type() == TargetType::GPU);
+    auto target = _target.as<GPUTarget>();
+
     auto prefix = mangle(func->name_);
     auto nParams = func->params_.size();
-    CodeGenCUDA visitor(func->params_, func->returns_, prefix);
+    CodeGenCUDA visitor(func->params_, func->returns_, target, prefix);
     auto &&op = func->body_;
     visitor.beginBlock();
     visitor(op);
diff --git a/src/driver.cc b/src/driver.cc
index 842a09059..7c61b7ced 100644
--- a/src/driver.cc
+++ b/src/driver.cc
@@ -268,6 +268,9 @@ void Driver::buildAndLoad() {
         addCommand(Config::backendCompilerNVCC().front());
         for (auto &&path : Config::runtimeDir()) {
             addArgs("-I" + (std::string)path);
+            // CUTLASS requires an explicit include path; it cannot be
+            // included via a relative path.
+            addArgs("-I" + (std::string)path + "/../3rd-party/cutlass/include");
         }
         addArgs("-std=c++17", "-shared", "-Xcompiler", "-fPIC,-Wall,-O3",
                 "--expt-relaxed-constexpr" /* required by mdspan */);
@@ -279,6 +282,8 @@ void Driver::buildAndLoad() {
         auto cc = dev_->target().as<GPUTarget>()->computeCapability();
         addArgs("-arch",
                 "sm_" + std::to_string(cc.first) + std::to_string(cc.second));
+        addArgs("-DFT_CUTLASS_ARCH=cutlass::arch::Sm" +
+                std::to_string(cc.first) + std::to_string(cc.second));
         if (Config::debugBinary()) {
             addArgs("-g");
         }
diff --git a/src/hash.cc b/src/hash.cc
index 6a5b5e2f6..9bfd5002c 100644
--- a/src/hash.cc
+++ b/src/hash.cc
@@ -150,6 +150,7 @@ size_t Hasher::compHash(const EvalNode &op) {
 
 size_t Hasher::compHash(const MatMulNode &op) {
     size_t h = ((size_t)op.nodeType() * K1 + B1) % P;
+    h = ((h + std::hash<MatMulBackend>()(op.backend_)) * K2 + B2) % P;
     h = ((h + op.equivalent_->hash()) * K2 + B2) % P;
     return (h * K3 + B3) % P;
 }
@@ -412,6 +413,9 @@ bool HashComparator::compare(const Eval &lhs, const Eval &rhs) const {
 }
 
 bool HashComparator::compare(const MatMul &lhs, const MatMul &rhs) const {
+    if (lhs->backend_ != rhs->backend_) {
+        return false;
+    }
     return (*this)(lhs->equivalent_, rhs->equivalent_);
 }
diff --git a/src/schedule/as_matmul.cc b/src/schedule/as_matmul.cc
index 589e6694a..bdb35d5a9 100644
--- a/src/schedule/as_matmul.cc
+++ b/src/schedule/as_matmul.cc
@@ -1,6 +1,7 @@
 #include
 #include
+#include <config.h>
 #include
 #include
 #include
@@ -201,9 +202,9 @@ Stmt AsMatMul::visit(const For &op) {
         } else {
             beta = makeIntConst(1);
         }
-        ret = makeMatMul(a_, b_, c_, alpha, beta, m_, k_, n_, lda_, ldb_, ldc_,
-                         stridea_, strideb_, stridec_, batchSize_, aIsRowMajor_,
-                         bIsRowMajor_, cIsRowMajor_, ret);
+        ret = makeMatMul(backend_, a_, b_, c_, alpha, beta, m_, k_, n_, lda_,
+                         ldb_, ldc_, stridea_, strideb_, stridec_, batchSize_,
+                         aIsRowMajor_, bIsRowMajor_, cIsRowMajor_, ret);
         for (auto &&def : innerDefs_) {
             ret = makeVarDef(def->name_, def->buffer_, def->viewOf_, ret,
                              def->pinned_, def->metadata(), def->id());
@@ -433,8 +434,8 @@ Stmt AsMatMul::visit(const VarDef &op) {
     }
 }
 
-Stmt asMatMul(const Stmt &_ast, const ID &loop) {
-    AsMatMul mutator(loop);
+Stmt asMatMul(const Stmt &_ast, const ID &loop, MatMulBackend backend) {
+    AsMatMul mutator(loop, backend);
     auto ast = simplify(_ast); // Simplify confusing loop range and indexing
                                // from libop. TODO: simplify only needed region
     ast = mutator(ast);
@@ -444,11 +445,12 @@ Stmt asMatMul(const Stmt &_ast, const ID &loop) {
     return ast;
 }
 
-void Schedule::asMatMul(const ID &loop, AsMatMulMode mode) {
+void Schedule::asMatMul(const ID &loop, AsMatMulMode mode,
+                        const Ref<Target> &target, MatMulBackend backend) {
     beginTransaction();
     while (true) {
-        auto log =
-            appendLog(MAKE_SCHEDULE_LOG(AsMatMul, freetensor::asMatMul, loop));
+        auto log = appendLog(
+            MAKE_SCHEDULE_LOG(AsMatMul, freetensor::asMatMul, loop, backend));
         try {
            applyLog(log);
            break;
@@ -483,4 +485,23 @@ void Schedule::asMatMul(const ID &loop, AsMatMulMode mode) {
     commitTransaction();
 }
 
+void Schedule::asMatMul(const ID &loop, AsMatMulMode mode,
+                        const Ref<Target> &target) {
+    switch (target->type()) {
+    case TargetType::CPU:
+        asMatMul(loop, mode, target, MatMulBackend::Mkl);
+        break;
+    case TargetType::GPU:
+        asMatMul(loop, mode, target, MatMulBackend::Cublas);
+        break;
+    default:
+        throw InvalidSchedule(ast(), "No default MatMul backend for target " +
+                                         toString(target));
+    }
+}
+
+void Schedule::asMatMul(const ID &loop, AsMatMulMode mode) {
+    asMatMul(loop, mode, Config::defaultTarget());
+}
+
 } // namespace freetensor
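The default-backend overloads above line up with the Python wrapper; a usage sketch of the three arities (schedule, target and loop label are hypothetical):

    import freetensor as ft

    s = ft.Schedule(func)
    s.as_matmul("L1")                                         # default target and backend
    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target)  # backend from target type
    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target,
                "cutlass")                                    # fully explicit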
diff --git a/src/schedule/auto_use_lib.cc b/src/schedule/auto_use_lib.cc
index 041c70bd8..2e6d2e0e4 100644
--- a/src/schedule/auto_use_lib.cc
+++ b/src/schedule/auto_use_lib.cc
@@ -10,7 +10,7 @@ void Schedule::autoUseLib(const Ref<Target> &target) {
         // Suppose the root node is not . It should be
         auto loop = _loop.as<ForNode>();
         try {
-            asMatMul(loop->id(), AsMatMulMode::TryTranspose);
+            asMatMul(loop->id(), AsMatMulMode::TryTranspose, target);
         } catch (const InvalidSchedule &e) {
             // If the loop is marked as preferLibs, we inline all local
             // variables, fission all the statements apart, and try applying to
@@ -50,7 +50,7 @@ void Schedule::autoUseLib(const Ref<Target> &target) {
                     fission(loop->id(), FissionSide::After, stmt->id(), true,
                             "." + toString(i) + ".lib", "")
                         .first.at(loop->id());
-                asMatMul(libStmtId, AsMatMulMode::TryTranspose);
+                asMatMul(libStmtId, AsMatMulMode::TryTranspose, target);
                 commitTransaction();
             } catch (const InvalidSchedule &e) {
                 abortTransaction();
diff --git a/test/40.codegen/gpu/test_gpu_cublas.py b/test/40.codegen/gpu/test_gpu_cublas.py
index 1f553bdbe..da98857b1 100644
--- a/test/40.codegen/gpu/test_gpu_cublas.py
+++ b/test/40.codegen/gpu/test_gpu_cublas.py
@@ -24,7 +24,7 @@ def test(a, b, c):
                     c[i, j] += a[i, k] * b[k, j]
 
     s = ft.Schedule(test)
-    s.as_matmul("L1")
+    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target, "cublas")
     func = ft.lower(s.func(), target, verbose=1)
     code = ft.codegen(func, target, verbose=True)
     assert "cublas" in code.code
diff --git a/test/40.codegen/gpu/test_gpu_cutlass.py b/test/40.codegen/gpu/test_gpu_cutlass.py
new file mode 100644
index 000000000..7215f330f
--- /dev/null
+++ b/test/40.codegen/gpu/test_gpu_cutlass.py
@@ -0,0 +1,70 @@
+import freetensor as ft
+from freetensor import debug
+import pytest
+import numpy as np
+
+if not ft.with_cuda():
+    pytest.skip("requires CUDA", allow_module_level=True)
+
+device = ft.GPU()
+target = device.target()
+
+
+def test_fp64():
+
+    @ft.transform
+    def test(a, b, c):
+        a: ft.Var[(48, 64), "float64", "input", "gpu/global"]
+        b: ft.Var[(64, 72), "float64", "input", "gpu/global"]
+        c: ft.Var[(48, 72), "float64", "inout", "gpu/global"]
+        #! label: L1
+        for i in range(48):
+            for j in range(72):
+                for k in range(64):
+                    c[i, j] += a[i, k] * b[k, j]
+
+    s = ft.Schedule(test)
+    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target, "cutlass")
+    func = ft.lower(s.func(), target, verbose=1)
+    code = ft.codegen(func, target, verbose=True)
+    assert "cutlass" in code.code
+    a_np = np.random.uniform(size=(48, 64)).astype("float64")
+    b_np = np.random.uniform(size=(64, 72)).astype("float64")
+    c_np = np.random.uniform(size=(48, 72)).astype("float64")
+    a_arr = ft.Array(a_np)
+    b_arr = ft.Array(b_np)
+    c_arr = ft.Array(c_np.copy())
+    ft.build_binary(code, device)(a=a_arr, b=b_arr, c=c_arr)
+    c_result = c_arr.numpy()
+
+    assert np.all(np.isclose(c_result, c_np + a_np @ b_np))
+
+
+def test_fp32():
+
+    @ft.transform
+    def test(a, b, c):
+        a: ft.Var[(48, 64), "float32", "input", "gpu/global"]
+        b: ft.Var[(64, 72), "float32", "input", "gpu/global"]
+        c: ft.Var[(48, 72), "float32", "inout", "gpu/global"]
+        #! label: L1
+        for i in range(48):
+            for j in range(72):
+                for k in range(64):
+                    c[i, j] += a[i, k] * b[k, j]
+
+    s = ft.Schedule(test)
+    s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target, "cutlass")
+    func = ft.lower(s.func(), target, verbose=1)
+    code = ft.codegen(func, target, verbose=True)
+    assert "cutlass" in code.code
+    a_np = np.random.uniform(size=(48, 64)).astype("float32")
+    b_np = np.random.uniform(size=(64, 72)).astype("float32")
+    c_np = np.random.uniform(size=(48, 72)).astype("float32")
+    a_arr = ft.Array(a_np)
+    b_arr = ft.Array(b_np)
+    c_arr = ft.Array(c_np.copy())
+    ft.build_binary(code, device)(a=a_arr, b=b_arr, c=c_arr)
+    c_result = c_arr.numpy()
+
+    assert np.all(np.isclose(c_result, c_np + a_np @ b_np))
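A possible follow-up test, sketched here rather than asserted as part of the patch: covering the default-backend path, where `as_matmul` receives a target but no backend and the Python wrapper resolves it to "cublas" for GPU targets.

    def test_default_backend():
        # Kernel `test` as in the tests above; backend left unspecified.
        s = ft.Schedule(test)
        s.as_matmul("L1", ft.AsMatMulMode.KeepMemLayout, target)
        func = ft.lower(s.func(), target, verbose=1)
        code = ft.codegen(func, target, verbose=True)
        assert "cublas" in code.code  # GPU default per schedule.py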