Skip to content

Commit

Permalink
fixed polymorphism compatibility with polynomial normalization of terms
Browse files Browse the repository at this point in the history
  • Loading branch information
joe-hauns committed Oct 4, 2024
1 parent 481cc35 commit bb3951e
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 45 deletions.
80 changes: 57 additions & 23 deletions Kernel/BottomUpEvaluation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#ifndef __LIB__BOTTOM_UP_EVALUATION_HPP__
#define __LIB__BOTTOM_UP_EVALUATION_HPP__

#define DEBUG(...) // DBG(__VA_ARGS__)

/**
* @file Kernel/BottomUpEvaluation.hpp
Expand All @@ -28,7 +27,9 @@
#include "Lib/Recycled.hpp"
#include "Lib/Option.hpp"
#include "Lib/TypeList.hpp"
#include "Debug/Tracer.hpp"
#include <utility>
#define DEBUG_BOTTOM_UP(lvl, ...) if (lvl < 0) DBG(__VA_ARGS__)

namespace Lib {
using EmptyContext = std::tuple<>;
Expand Down Expand Up @@ -351,14 +352,15 @@ class BottomUpEvaluation {

BottomUpChildIter<Arg> orig = recState->pop();

ASS_GE(recResults.size(), orig.nChildren(_context))
Result* argLst = orig.nChildren(_context) == 0
? nullptr
: static_cast<Result*>(&((*recResults)[recResults->size() - orig.nChildren(_context)]));

Result eval = _memo.getOrInit(orig.self(),
[&](){ return _function(orig.self(), argLst); });

DEBUG("evaluated: ", orig.self(), " -> ", eval);
DEBUG_BOTTOM_UP(0, "evaluated: ", orig.self(), " -> ", eval);
recResults->pop(orig.nChildren(_context));
recResults->push(std::move(eval));
}
Expand All @@ -368,17 +370,23 @@ class BottomUpEvaluation {

ASS(recResults->size() == 1);
auto result = recResults->pop();
DEBUG("eval result: ", toEval, " -> ", result);
DEBUG_BOTTOM_UP(0, "eval result: ", toEval, " -> ", result);
return result;
}
};


}
#undef DEBUG

#include "Kernel/Term.hpp"

namespace Lib {

struct TermListContext {
bool ignoreTypeArgs = true;
};

// specialisation for TermList
// iterate up through TermLists, ignoring sort arguments
template<>
Expand All @@ -387,20 +395,29 @@ struct BottomUpChildIter<Kernel::TermList>
Kernel::TermList _self;
unsigned _idx;

BottomUpChildIter(Kernel::TermList self, EmptyContext = EmptyContext()) : _self(self), _idx(0)
BottomUpChildIter(Kernel::TermList self, TermListContext c) : _self(self), _idx(0)
{ }
BottomUpChildIter(Kernel::TermList self, EmptyContext = EmptyContext()) : BottomUpChildIter(self, TermListContext()) {}

Kernel::TermList next(EmptyContext = EmptyContext())
{ return next(TermListContext()); }

Kernel::TermList next(TermListContext ctx)
{
ASS(hasNext());
return _self.term()->termArg(_idx++);
ASS(hasNext(ctx));
return ctx.ignoreTypeArgs ? _self.term()->termArg(_idx++)
: *_self.term()->nthArgument(_idx++);
}

bool hasNext(EmptyContext = EmptyContext()) const
bool hasNext(EmptyContext = EmptyContext()) const { return hasNext(TermListContext()); }
bool hasNext(TermListContext c) const
{ return _self.isTerm() && _idx < _self.term()->numTermArguments(); }

unsigned nChildren(EmptyContext = EmptyContext()) const
{ return _self.isVar() ? 0 : _self.term()->numTermArguments(); }
unsigned nChildren(EmptyContext = EmptyContext()) const { return nChildren(TermListContext()); }
unsigned nChildren(TermListContext c) const
{ return _self.isVar() ? 0
: ( c.ignoreTypeArgs ? _self.term()->numTermArguments()
: _self.term()->arity()); }

Kernel::TermList self(EmptyContext = EmptyContext()) const
{ return _self; }
Expand All @@ -410,35 +427,53 @@ struct BottomUpChildIter<Kernel::TermList>
#include "TypedTermList.hpp"

namespace Lib {

// specialisation for TypedTermList
template<>
struct BottomUpChildIter<Kernel::TypedTermList>
{
Kernel::TypedTermList _self;
unsigned _idx;

BottomUpChildIter(Kernel::TypedTermList self, EmptyContext = EmptyContext()) : _self(self), _idx(0)
BottomUpChildIter(Kernel::TypedTermList self, EmptyContext) : BottomUpChildIter(self, TermListContext()) {}
BottomUpChildIter(Kernel::TypedTermList self, TermListContext ctx) : _self(self), _idx(0)
{}

Kernel::TypedTermList next(int);
Kernel::TypedTermList next(EmptyContext = EmptyContext())
Kernel::TypedTermList next(EmptyContext) { return next(TermListContext()); }
Kernel::TypedTermList next(TermListContext ctx)
{
ASS(hasNext());
ASS(hasNext(ctx));
auto cur = self().term();
auto next = cur->termArg(_idx);
auto sort = Kernel::SortHelper::getTermArgSort(cur, _idx);
ASS_NEQ(sort, Kernel::AtomicSort::superSort())
Kernel::TypedTermList out;
if (ctx.ignoreTypeArgs) {
out = Kernel::TypedTermList(cur->termArg(_idx),
Kernel::SortHelper::getTermArgSort(cur, _idx));
ASS_NEQ(out.sort(), Kernel::AtomicSort::superSort())
} else {
out = Kernel::TypedTermList(*cur->nthArgument(_idx),
Kernel::SortHelper::getArgSort(cur, _idx));
}
_idx++;
return Kernel::TypedTermList(next, sort);
return out;
}

bool hasNext(EmptyContext = EmptyContext()) const
{ return _self.isTerm() && _idx < _self.term()->numTermArguments(); }

unsigned nChildren(EmptyContext = EmptyContext()) const
{ return _self.isVar() ? 0 : _self.term()->numTermArguments(); }
bool hasNext(EmptyContext) const { return hasNext(TermListContext()); }
bool hasNext(TermListContext ctx) const
{ return _self.isTerm() && (ctx.ignoreTypeArgs
? _idx < _self.term()->numTermArguments()
: _idx < _self.term()->arity()); }

unsigned nChildren(EmptyContext) const { return nChildren(TermListContext()); }
unsigned nChildren(TermListContext ctx = TermListContext{}) const
{
return _self.isVar() ? 0
: (ctx.ignoreTypeArgs ? _self.term()->numTermArguments()
: _self.term()->arity());
}

Kernel::TypedTermList self(EmptyContext = EmptyContext()) const
Kernel::TypedTermList self(EmptyContext) const { return self(TermListContext()); }
Kernel::TypedTermList self(TermListContext ctx = TermListContext{}) const
{ return _self; }
};

Expand Down Expand Up @@ -583,6 +618,5 @@ struct BottomUpChildIter<Kernel::PolyNf>

} // namespace Lib

#undef DEBUG
#endif // __LIB__BOTTOM_UP_EVALUATION_HPP__

17 changes: 9 additions & 8 deletions Kernel/Polynomial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "Kernel/Polynomial.hpp"
#include "Kernel/PolynomialNormalizer.hpp"
#include "Debug/Output.hpp"

#define DEBUG(...) // DBG(__VA_ARGS__)

Expand Down Expand Up @@ -47,22 +48,22 @@ std::ostream& operator<<(std::ostream& out, const Variable& self)
// impl FuncId
/////////////////////////////////////////////////////////

FuncId::FuncId(unsigned num, const TermList* typeArgs) : _num(num) /*, _typeArgs(typeArgs)*/ {}
FuncId::FuncId(unsigned num, const TermList* typeArgs) : _num(num), _typeArgs(typeArgs) {}

FuncId FuncId::symbolOf(Term* term)
{ return FuncId(term->functor(), term->typeArgs()); }

unsigned FuncId::numTermArguments()
{ return symbol()->numTermArguments(); }

bool operator==(FuncId const& lhs, FuncId const& rhs)
{ return lhs._num == rhs._num; }

bool operator!=(FuncId const& lhs, FuncId const& rhs)
{ return !(lhs == rhs); }

std::ostream& operator<<(std::ostream& out, const FuncId& self)
{ return out << self.symbol()->name(); }
{
if (self.numTypeArgs() == 0) {
return out << self.symbol()->name();
} else {
return out << self.symbol()->name() << "<" << outputInterleaved(", ", self.iterTypeArgs()) << ">";
}
}

Signature::Symbol* FuncId::symbol() const
{ return env.signature->getFunction(_num); }
Expand Down
23 changes: 11 additions & 12 deletions Kernel/Polynomial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "Kernel/NumTraits.hpp"
#include "Kernel/Ordering.hpp"
#include "Kernel/TypedTermList.hpp"
#include "Lib/Reflection.hpp"
#include <type_traits>

#define DEBUG(...) // DBG(__VA_ARGS__)
Expand Down Expand Up @@ -76,17 +77,18 @@ namespace Kernel {
class FuncId
{
unsigned _num;
// const TermList* _typeArgs; // private field not used
const TermList* _typeArgs;

public:
explicit FuncId(unsigned num, const TermList* typeArgs);
static FuncId symbolOf(Term* term);
unsigned numTermArguments();
TermList typeArg(unsigned i) const { return *(_typeArgs - i); }
unsigned numTypeArgs() const { return env.signature->getFunction(_num)->numTypeArguments(); }

friend struct std::hash<FuncId>;
friend bool operator==(FuncId const& lhs, FuncId const& rhs);
friend bool operator!=(FuncId const& lhs, FuncId const& rhs);
friend std::ostream& operator<<(std::ostream& out, const FuncId& self);
auto iterTypeArgs() const
{ return range(0, numTypeArgs()).map([&](auto i) { return typeArg(i); }); }

Signature::Symbol* symbol() const;

Expand All @@ -97,18 +99,15 @@ class FuncId

template<class Number>
Option<typename Number::ConstantType> tryNumeral() const;

auto asTuple() const { return std::tuple(_num, iterContOps(iterTypeArgs())); }
IMPL_COMPARISONS_FROM_TUPLE(FuncId)
IMPL_HASH_FROM_TUPLE(FuncId)
};

} // namespace Kernel


template<> struct std::hash<Kernel::FuncId>
{
size_t operator()(Kernel::FuncId const& f) const
{ return std::hash<unsigned>{}(f._num); }
};


/////////////////////////////////////////////////////////////////////////////////////////////
// forward declarations, needed to define PolyNf structure
/////////////////////////////////////////////////////////
Expand Down Expand Up @@ -621,7 +620,7 @@ Option<typename Number::ConstantType> FuncTerm::tryNumeral() const
template<> struct std::hash<Kernel::FuncTerm>
{
size_t operator()(Kernel::FuncTerm const& f) const
{ return Lib::HashUtils::combine(std::hash<Kernel::FuncId>{}(f._fun), std::hash<Stack<Kernel::PolyNf>>{}(f._args)); }
{ return Lib::HashUtils::combine(f._fun.defaultHash(), std::hash<Stack<Kernel::PolyNf>>{}(f._args)); }
};

/////////////////////////////////////////////////////////
Expand Down
9 changes: 7 additions & 2 deletions Kernel/PolynomialNormalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "PolynomialNormalizer.hpp"
#include "Kernel/BottomUpEvaluation.hpp"

#define DEBUG(...) //DBG(__VA_ARGS__)
#define DEBUG(...) // DBG(__VA_ARGS__)

namespace Kernel {

Expand Down Expand Up @@ -363,7 +363,12 @@ TermList PolyNf::denormalize() const
.function(
[&](PolyNf orig, TermList* results) -> TermList
{ return orig.match(
[&](Perfect<FuncTerm> t) { return TermList(Term::create(t->function().id(), t->numTermArguments(), results)); },
[&](Perfect<FuncTerm> t) {
return TermList(Term::createFromIter(t->function().id(),
concatIters(
t->function().iterTypeArgs(),
range(0, t->numTermArguments()).map([&](auto i){ return results[i]; })
))); },
[&](Variable v) { return TermList::var(v.id()); },
[&](AnyPoly p) { return p.denormalize(results); }
); })
Expand Down
37 changes: 37 additions & 0 deletions Lib/Metaiterators.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,4 +1704,41 @@ template<class Inner>
auto getPersistentIterator(Inner it)
{ return pvi(arrayIter(iterTraits(it).template collect<Stack>())); }

template<class Iter>
class IterContOps {
Iter const _iter;

public:
IterContOps(Iter iter) : _iter(std::move(iter)) {}

auto defaultHash() const { return DefaultHash::hashIter(Iter(_iter).map([](ELEMENT_TYPE(Iter) x) -> unsigned { return DefaultHash::hash(x); })); }
auto defaultHash2() const { return DefaultHash::hashIter(Iter(_iter).map([](ELEMENT_TYPE(Iter) x) -> unsigned { return DefaultHash2::hash(x); })); }

static int cmp(IterContOps const& lhs, IterContOps const& rhs) {
auto l = lhs._iter;
auto r = rhs._iter;
while (l.hasNext() && r.hasNext()) {
auto ln = l.next();
auto rn = r.next();
if (ln < rn) {
return -1;
} else if (rn < ln) {
return 1;
}
}
return !l.hasNext() ? (r.hasNext() ? -1 : 0) : 1;
}
friend bool operator<(IterContOps const& lhs, IterContOps const& rhs)
{ return cmp(lhs, rhs) < 0; }

friend bool operator==(IterContOps const& lhs, IterContOps const& rhs)
{ return cmp(lhs, rhs) == 0; }

friend bool operator!=(IterContOps const& lhs, IterContOps const& rhs)
{ return !(lhs == rhs); }
};

template<class Iter>
auto iterContOps(Iter iter) { return IterContOps<Iter>(std::move(iter)); }

#endif /* __Metaiterators__ */

0 comments on commit bb3951e

Please sign in to comment.