Skip to content

Commit

Permalink
Add schedule/parallelize_as (#573)
Browse files Browse the repository at this point in the history
* Detect strides in pass/shrink_for

* Minor improvement to analyze/deps + Update CI script for current cluster config

* Add schedule/parallelize_as
  • Loading branch information
roastduck authored Jan 10, 2024
1 parent f04ff6e commit 2811fe4
Show file tree
Hide file tree
Showing 25 changed files with 796 additions and 95 deletions.
41 changes: 31 additions & 10 deletions ffi/parallel_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,40 @@ void init_ffi_parallel_scope(py::module_ &m) {
.def(py::init<>())
.def("__str__",
[](const SerialScope &scope) { return toString(scope); })
.def("__eq__", [](const SerialScope &lhs, const SerialScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const SerialScope &lhs,
const SerialScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const SerialScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const SerialScope &lhs, py::object rhs) { return false; });

py::class_<OpenMPScope>(m, "OpenMPScope")
.def(py::init<>())
.def("__str__",
[](const OpenMPScope &scope) { return toString(scope); })
.def("__eq__", [](const OpenMPScope &lhs, const OpenMPScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const OpenMPScope &lhs,
const OpenMPScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const OpenMPScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const OpenMPScope &lhs, py::object rhs) { return false; });

py::class_<CUDAStreamScope>(m, "CUDAStreamScope")
.def(py::init<>())
.def("__str__",
[](const CUDAStreamScope &scope) { return toString(scope); })
.def("__eq__", [](const CUDAStreamScope &lhs,
const CUDAStreamScope &rhs) { return lhs == rhs; });
const CUDAStreamScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const CUDAStreamScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const CUDAStreamScope &lhs, py::object rhs) { return false; });

py::enum_<CUDAScope::Level>(m, "CUDAScopeLevel")
.value("Block", CUDAScope::Level::Block)
Expand All @@ -40,9 +56,14 @@ void init_ffi_parallel_scope(py::module_ &m) {
return CUDAScope{level, dim};
}))
.def("__str__", [](const CUDAScope &scope) { return toString(scope); })
.def("__eq__", [](const CUDAScope &lhs, const CUDAScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const CUDAScope &lhs,
const CUDAScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const CUDAScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const CUDAScope &lhs, py::object rhs) { return false; });

// Factory function, used as a class
m.def("ParallelScope", &parseParallelScope);
Expand Down
8 changes: 4 additions & 4 deletions ffi/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ void init_ffi_pass(py::module_ &m) {
"stmt"_a);

m.def("shrink_for",
static_cast<Func (*)(const Func &, const Stmt &, const bool &)>(
static_cast<Func (*)(const Func &, const ID &, const bool &)>(
&shrinkFor),
"func"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true);
"func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
m.def("shrink_for",
static_cast<Stmt (*)(const Stmt &, const Stmt &, bool)>(&shrinkFor),
"stmt"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true);
static_cast<Stmt (*)(const Stmt &, const ID &, bool)>(&shrinkFor),
"stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);

m.def("merge_and_hoist_if",
static_cast<Func (*)(const Func &)>(&mergeAndHoistIf), "func"_a);
Expand Down
2 changes: 2 additions & 0 deletions ffi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ void init_ffi_schedule(py::module_ &m) {
.def("inline", &Schedule::inlining, "vardef"_a)
.def("parallelize", &Schedule::parallelize, "loop"_a, "parallel"_a,
"allow_reduction"_a = true)
.def("parallelize_as", &Schedule::parallelizeAs, "nest"_a,
"reference"_a, "def_id"_a)
.def("unroll", &Schedule::unroll, "loop"_a, "immedate"_a = false)
.def("vectorize", &Schedule::vectorize, "loop"_a)
.def("separate_tail", &Schedule::separateTail,
Expand Down
7 changes: 5 additions & 2 deletions include/analyze/comp_unique_bounds_pb.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace freetensor {
class CompUniqueBoundsPB : public CompUniqueBounds {
public:
class Bound : public CompUniqueBounds::Bound {
public: // Visible to CompUniqueBoundsPB's subclasses
Ref<PBCtx> ctx_;
// isl var -> ft expr, the demangling map yielded from GenPBExpr
// shared from CompUniqueBoundsPB::cachedFreeVars_
Expand All @@ -28,8 +29,6 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
// single output being the bounded expression
PBSet bound_;

friend class CompUniqueBoundsPB;

public:
Bound(Ref<PBCtx> ctx,
Ref<std::unordered_map<std::string, Expr>> demangleMap,
Expand Down Expand Up @@ -63,6 +62,10 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
Ref<std::unordered_map<std::string, Expr>> cachedFreeVars_;
std::unordered_map<Expr, Ref<Bound>> cachedValues_;

protected:
Ref<CompUniqueBoundsPB::Bound>
unionBoundsAsBound(const std::vector<Ref<CompUniqueBounds::Bound>> &bounds);

public:
CompUniqueBoundsPB(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients), transients_(transients),
Expand Down
19 changes: 14 additions & 5 deletions include/analyze/deps.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct AccessPoint {
Ref<Buffer> buffer_;
int defAxis_; /// The position of the VarDef
std::vector<IterAxis> iter_; /// The temporal location of the access
std::vector<Expr> access_; /// The spacial location of the access
std::vector<Expr> access_; /// The spatial location of the access
std::vector<std::pair<Expr, ID>>
conds_; /// - first: The condition (predicate) of the access
/// - second: the statement that contribute to the condition)
Expand Down Expand Up @@ -365,16 +365,25 @@ class AnalyzeDeps {
const std::vector<Expr> &list,
RelaxMode relax,
GenPBExpr::VarMap &externals);
static std::string makeCond(GenPBExpr &genPBExpr,
const std::vector<std::pair<Expr, ID>> &conds,
RelaxMode relax, GenPBExpr::VarMap &externals,
static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
GenPBExpr::VarMap &externals,
bool eraseOutsideVarDef, const AccessPoint &ap);
static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef);

private:
PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim,
int accDim, RelaxMode relax, const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars);
const ASTHashSet<Expr> &noNeedToBeVars) {
return makeAccMapStatic(presburger, p, iterDim, accDim, relax,
extSuffix, externals, noNeedToBeVars,
eraseOutsideVarDef_);
}

PBMap makeEqForBothOps(PBCtx &presburger,
const std::vector<std::pair<int, int>> &coord,
Expand Down
14 changes: 14 additions & 0 deletions include/get_new_name.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef FREE_TENSOR_GET_NEW_NAME_H
#define FREE_TENSOR_GET_NEW_NAME_H

#include <string>
#include <unordered_set>

namespace freetensor {

/**
 * Pick a fresh name derived from `oldName` that does not collide with any
 * name in `used`
 *
 * @param oldName : The preferred base name. NOTE(review): presumably returned
 * unchanged when it is not already in `used`, with a suffix or similar
 * mangling applied otherwise — confirm against the implementation, which is
 * not visible here.
 * @param used : The set of names already taken.
 * @return : A name that is not contained in `used`.
 */
std::string getNewName(const std::string &oldName,
                       const std::unordered_set<std::string> &used);

}

#endif // FREE_TENSOR_GET_NEW_NAME_H
2 changes: 2 additions & 0 deletions include/math/presburger.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ class PBSingleFunc {
public:
PBSingleFunc() {}
PBSingleFunc(isl_pw_aff *func) : func_(func) {}
explicit PBSingleFunc(isl_aff *func) : func_(isl_pw_aff_from_aff(func)) {}

~PBSingleFunc() {
if (func_ != nullptr) {
isl_pw_aff_free(func_);
Expand Down
10 changes: 8 additions & 2 deletions include/pass/shrink_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
/**
* Increase the begin and decrease the end index, to remove redundant iterations
* from For loops
*
* @{
*/
Stmt shrinkFor(const Stmt &op, const Stmt &subAST = nullptr,
bool doSimplify = true);
Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true);
inline Stmt shrinkFor(const Stmt &op, const Stmt &subAST,
bool doSimplify = true) {
return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify);
}
/** @} */

DEFINE_PASS_FOR_FUNC(shrinkFor)

Expand Down
15 changes: 15 additions & 0 deletions include/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,21 @@ class Schedule {
void parallelize(const ID &loop, const ParallelScope &parallel,
bool allowReduction = true);

/**
* Parallelize a loop nest according to another loop nest to keep a tensor
* thread-local
*
* @param nest : ID of the loop nest to be parallelized. The ID can be of
* any statement type, and all statements it contains will be parallelized.
* @param reference : ID of the loop nest to be referenced. The ID can be of
* any statement type, and all statements it contains will be referenced.
* @param defId : ID of the VarDef statement of the tensor to be kept
* thread-local.
* @throw InvalidSchedule if any of the ID is not found, or the reference
* loop nest is already thread-non-local.
*/
void parallelizeAs(const ID &nest, const ID &reference, const ID &defId);

/**
* Unroll a loop
*
Expand Down
13 changes: 13 additions & 0 deletions include/schedule/parallelize_as.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef FREE_TENSOR_PARALLELIZE_AS_H
#define FREE_TENSOR_PARALLELIZE_AS_H

#include <stmt.h>

namespace freetensor {

/**
 * Parallelize a loop nest according to another loop nest to keep a tensor
 * thread-local
 *
 * This is the AST-level transformation backing `Schedule::parallelizeAs`.
 *
 * @param ast : The AST to transform.
 * @param nest : ID of the loop nest to be parallelized. The ID can be of any
 * statement type, and all statements it contains will be parallelized.
 * @param reference : ID of the loop nest to be referenced. The ID can be of
 * any statement type, and all statements it contains will be referenced.
 * @param defId : ID of the VarDef statement of the tensor to be kept
 * thread-local.
 * @throw InvalidSchedule if any of the ID is not found, or the reference loop
 * nest is already thread-non-local.
 * @return : The transformed AST.
 */
Stmt parallelizeAs(const Stmt &ast, const ID &nest, const ID &reference,
                   const ID &defId);

} // namespace freetensor

#endif // FREE_TENSOR_PARALLELIZE_AS_H
17 changes: 9 additions & 8 deletions include/schedule/schedule_log.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ enum class ScheduleType : int {
VarReorder,
Inline,
Parallelize,
ParallelizeAs,
Unroll,
Vectorize,
SeparateTail,
Expand All @@ -42,14 +43,14 @@ enum class ScheduleType : int {
};

constexpr std::array scheduleTypeNames = {
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "inline", "parallelize",
"unroll", "vectorize", "separate_tail",
"as_matmul", "permute", "pluto_fuse",
"pluto_permute",
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "inline", "parallelize",
"parallelize_as", "unroll", "vectorize",
"separate_tail", "as_matmul", "permute",
"pluto_fuse", "pluto_permute",
};
static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes);

Expand Down
25 changes: 25 additions & 0 deletions python/freetensor/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,31 @@ def parallelize(self, loop, parallel):
"""
super().parallelize(self._lookup(loop), ParallelScope(parallel))

def parallelize_as(self, nest, reference, vardef):
    '''
    Parallelize a loop nest according to another loop nest to keep a tensor
    thread-local

    Parameters
    ----------
    nest : str, ID or Stmt
        The loop nest to be parallelized. The ID can be of any statement type,
        and all statements it contains will be parallelized.
    reference : str, ID or Stmt
        The loop nest to be referenced. The ID can be of any statement type,
        and all statements it contains will be referenced.
    vardef : str, ID or Stmt
        The VarDef statement of the tensor to be kept thread-local.

    Raises
    ------
    InvalidSchedule
        if any of the ID is not found, or the reference loop nest is already
        thread-non-local.
    '''
    # Resolve str/Stmt arguments to IDs before delegating to the native
    # Schedule::parallelizeAs binding.
    super().parallelize_as(self._lookup(nest), self._lookup(reference),
                           self._lookup(vardef))

def unroll(self, loop, immediate=False):
"""
Unroll a loop
Expand Down
32 changes: 17 additions & 15 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ std::optional<int64_t> CompUniqueBoundsPB::Bound::getInt() const {

namespace {

// Translate a computed bound function in ISL back to our Expr.
// We prefer min/max expressions as final results, so we first test if simply
// min/max of the pieces gives the correct result; if not, fallback to IfExpr.
Expr translateBoundFunc(
PBCtx &ctx, const PBSet &boundSet,
const std::unordered_map<std::string, Expr> &demangleMap) {
Expand Down Expand Up @@ -202,11 +199,10 @@ bool CompUniqueBoundsPB::alwaysLT(const Expr &lhs, const Expr &rhs) {
return combined.empty();
}

std::pair<Expr, Expr> CompUniqueBoundsPB::unionBounds(
Ref<CompUniqueBoundsPB::Bound> CompUniqueBoundsPB::unionBoundsAsBound(
const std::vector<Ref<CompUniqueBounds::Bound>> &_bounds) {
// if no bound presented, return an empty range
if (_bounds.size() == 0)
return {makeIntConst(0), makeIntConst(-1)};
return nullptr;

// PBSet in _bounds may be from foreign ctx. Reconstruct them in our ctx
auto bounds = ranges::to<std::vector>(
Expand All @@ -225,7 +221,7 @@ std::pair<Expr, Expr> CompUniqueBoundsPB::unionBounds(
bound = coalesce(std::move(bound));

// construct the demangle map
std::unordered_map<std::string, Expr> demangleMap;
auto demangleMap = Ref<std::unordered_map<std::string, Expr>>::make();
for (isl_size dim = 0; dim < bound.nParamDims(); ++dim) {
auto dimName = bound.nameParamDim(dim);
Expr demangled;
Expand All @@ -240,17 +236,23 @@ std::pair<Expr, Expr> CompUniqueBoundsPB::unionBounds(
}
}
}
demangleMap[dimName] = demangled;
(*demangleMap)[dimName] = demangled;
}

return Ref<CompUniqueBoundsPB::Bound>::make(ctx_, demangleMap, bound);
}

std::pair<Expr, Expr> CompUniqueBoundsPB::unionBounds(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) {
auto bound = unionBoundsAsBound(bounds);

// if no bound presented, return an empty range
if (!bound.isValid()) {
return {makeIntConst(0), makeIntConst(-1)};
}

// translate the lower and upper bounds back to expression
auto l = bound.hasLowerBound(0)
? translateBoundFunc(*ctx_, lexmin(bound), demangleMap)
: nullptr;
auto u = bound.hasUpperBound(0)
? translateBoundFunc(*ctx_, lexmax(bound), demangleMap)
: nullptr;
return {l, u};
return {bound->lowerExpr(), bound->upperExpr()};
}

} // namespace freetensor
Loading

0 comments on commit 2811fe4

Please sign in to comment.