diff --git a/lib/Infer/ExhaustiveSynthesis.cpp b/lib/Infer/ExhaustiveSynthesis.cpp index 3d646cd99..6b4dafb69 100644 --- a/lib/Infer/ExhaustiveSynthesis.cpp +++ b/lib/Infer/ExhaustiveSynthesis.cpp @@ -50,6 +50,8 @@ namespace { static cl::opt EnableBigQuery("souper-exhaustive-synthesis-enable-big-query", cl::desc("Enable big query in exhaustive synthesis (default=false)"), cl::init(false)); + static cl::opt CostModel("souper-cost-model", cl::desc("Cost Model"), + cl::init("default")); } // TODO @@ -86,6 +88,60 @@ namespace { // experiment with synthesizing at reduced bitwidth, then expanding the result // aggressively avoid calling into the solver +using CostFunctionType = std::function; + +float defaultCostFunction(Inst* I, bool IgnoreDepsWithExternalUses) { + return souper::cost(I, IgnoreDepsWithExternalUses); +} + +// TODO : Derive these randomly to maximize infer rate with test-infer.sh +float getCost(Inst::Kind K) { + switch (K) { + case souper::Inst::Var: + case souper::Inst::Const: + case souper::Inst::Phi: + return 0; + case souper::Inst::BSwap: + case souper::Inst::CtPop: + case souper::Inst::Cttz: + case souper::Inst::Ctlz: + case souper::Inst::SDiv: + case souper::Inst::UDiv: + case souper::Inst::SRem: + case souper::Inst::URem: + return 5; + case souper::Inst::Select: + return 3; + case souper::Inst::Mul: + return 2.5f; + case souper::Inst::Add: + case souper::Inst::Sub: + return 1.5f; + default: + return 1; + } +} + +// TODO: Recognize simple patterns +static float costHelper(Inst *I, Inst *Root, std::set &Visited, + bool IgnoreDepsWithExternalUses) { + if (!Visited.insert(I).second) + return 0; + if (IgnoreDepsWithExternalUses && I != Root && + Root->DepsWithExternalUses.find(I) != Root->DepsWithExternalUses.end()) { + return 0; + } + float Cost = getCost(I->K); + for (auto Op : I->Ops) + Cost += costHelper(Op, Root, Visited, IgnoreDepsWithExternalUses); + return Cost; +} + +float simpleHeuristicsCostFunction(Inst* I, bool IgnoreDepsWithExternalUses) { + std::set Visited; + return costHelper(I, I, Visited, IgnoreDepsWithExternalUses); +} + void hasConstantHelper(Inst *I, std::set &Visited, std::vector &ConstList) { // FIXME this only works for one constant and keying by name is bad @@ -119,9 +175,9 @@ std::vector matchWidth(Inst *I, unsigned NewW, InstContext &IC) { return { I }; } -void addGuess(Inst *RHS, int MaxCost, std::vector &Guesses, - int &TooExpensive) { - if (souper::cost(RHS) < MaxCost) +void addGuess(Inst *RHS, float MaxCost, std::vector &Guesses, + float &TooExpensive, CostFunctionType &Cost) { + if (Cost(RHS, false) < MaxCost) Guesses.push_back(RHS); else TooExpensive++; @@ -138,9 +194,9 @@ bool prune (Inst *I, std::vector &ReservedInsts) { void getGuesses(std::vector &Guesses, const std::vector &Inputs, - int Width, int LHSCost, + int Width, float LHSCost, InstContext &IC, Inst *PrevInst, Inst *PrevSlot, - int &TooExpensive) { + float &TooExpensive, CostFunctionType &Cost) { std::vector PartialGuesses; @@ -177,7 +233,7 @@ void getGuesses(std::vector &Guesses, for (auto V : matchWidth(Comp, Width, IC)) { auto N = IC.getInst(K, Width, { V }); - addGuess(N, LHSCost, PartialGuesses, TooExpensive); + addGuess(N, LHSCost, PartialGuesses, TooExpensive, Cost); } } } @@ -264,7 +320,7 @@ void getGuesses(std::vector &Guesses, continue; auto N = IC.getInst(K, Inst::isCmp(K) ? 1 : OpWidth, { V1i, V2i }); for (auto MatchedWidthN : matchWidth(N, Width, IC)) { - addGuess(MatchedWidthN, LHSCost, PartialGuesses, TooExpensive); + addGuess(MatchedWidthN, LHSCost, PartialGuesses, TooExpensive, Cost); } } } @@ -318,7 +374,7 @@ void getGuesses(std::vector &Guesses, auto MatchedWidthL = matchWidth(L, 1, IC); auto SelectInst = IC.getInst(Inst::Select, Width, { MatchedWidthL[0], V1i, V2i }); - addGuess(SelectInst, LHSCost, PartialGuesses, TooExpensive); + addGuess(SelectInst, LHSCost, PartialGuesses, TooExpensive, Cost); } } } @@ -351,7 +407,7 @@ void getGuesses(std::vector &Guesses, if (prune(JoinedGuess, CurrSlots)) { for (auto S : CurrSlots) getGuesses(Guesses, Inputs, S->Width, - LHSCost, IC, JoinedGuess, S, TooExpensive); + LHSCost, IC, JoinedGuess, S, TooExpensive, Cost); } } } @@ -445,18 +501,24 @@ ExhaustiveSynthesis::synthesize(SMTLIBSolver *SMTSolver, if (DebugLevel > 1) llvm::errs() << "got " << Inputs.size() << " candidates from LHS\n"; + CostFunctionType Cost; + if (CostModel == "default") { + Cost = defaultCostFunction; + } else { + Cost = simpleHeuristicsCostFunction; + } - int LHSCost = souper::cost(LHS, /*IgnoreDepsWithExternalUses=*/true); + float LHSCost = Cost(LHS, /*IgnoreDepsWithExternalUses=*/true); - int TooExpensive = 0; + float TooExpensive = 0; std::vector Guesses; getGuesses(Guesses, Inputs, LHS->Width, - LHSCost, IC, nullptr, nullptr, TooExpensive); + LHSCost, IC, nullptr, nullptr, TooExpensive, Cost); // add nops guesses separately for (auto I : Inputs) { for (auto V : matchWidth(I, LHS->Width, IC)) - addGuess(V, LHSCost, Guesses, TooExpensive); + addGuess(V, LHSCost, Guesses, TooExpensive, Cost); } std::error_code EC; @@ -467,8 +529,8 @@ ExhaustiveSynthesis::synthesize(SMTLIBSolver *SMTSolver, // CEGIS is that we can synthesize in precisely increasing cost // order, and not try to somehow teach the solver how to do that std::stable_sort(Guesses.begin(), Guesses.end(), - [](Inst *a, Inst *b) -> bool { - return souper::cost(a) < souper::cost(b); + [&Cost](Inst *a, Inst *b) -> bool { + return Cost(a, false) < Cost(b, false); }); if (DebugLevel > 1) diff --git a/test/Infer/syn-shl-from-add.opt b/test/Infer/syn-shl-from-add.opt new file mode 100644 index 000000000..35454f261 --- /dev/null +++ b/test/Infer/syn-shl-from-add.opt @@ -0,0 +1,11 @@ +; REQUIRES: solver, solver-model + +; RUN: %souper-check %solver -reinfer-rhs -souper-exhaustive-synthesis -souper-exhaustive-synthesis-num-instructions=1 -souper-cost-model=heuristics %s > %t +; RUN: %FileCheck %s < %t + +%b:i64 = var +%a = add %b, %b +infer %a +%r2 = shl %b, 1 +result %r2 +; CHECK: RHS inferred successfully, no cost regression