Enable spilling for partial aggregation (7558)
JkSelf committed Nov 21, 2023
1 parent b2fdc6e commit 8d8b7d5
Showing 5 changed files with 160 additions and 18 deletions.
3 changes: 1 addition & 2 deletions velox/core/PlanNode.cpp
@@ -237,8 +237,7 @@ bool AggregationNode::canSpill(const QueryConfig& queryConfig) const {
}
// TODO: add spilling for pre-grouped aggregation later:
// https://github.com/facebookincubator/velox/issues/3264
- return (isFinal() || isSingle()) && preGroupedKeys().empty() &&
-     queryConfig.aggregationSpillEnabled();
+ return preGroupedKeys().empty() && queryConfig.aggregationSpillEnabled();
}

void AggregationNode::addDetails(std::stringstream& stream) const {
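In short, this change drops the isFinal() || isSingle() requirement from AggregationNode::canSpill(), so partial and intermediate aggregations become spill candidates as well. A minimal sketch of the method after this commit, reassembled from the hunk above (checks earlier in the method that the hunk does not show are elided):

bool AggregationNode::canSpill(const QueryConfig& queryConfig) const {
  // ... earlier checks, not shown in the hunk ...
  // TODO: add spilling for pre-grouped aggregation later:
  // https://github.com/facebookincubator/velox/issues/3264
  return preGroupedKeys().empty() && queryConfig.aggregationSpillEnabled();
}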
8 changes: 4 additions & 4 deletions velox/core/tests/PlanFragmentTest.cpp
@@ -159,14 +159,14 @@ TEST_F(PlanFragmentTest, aggregationCanSpill) {
{AggregationNode::Step::kSingle, true, true, false, false, true},
{AggregationNode::Step::kIntermediate, false, true, false, false, false},
{AggregationNode::Step::kIntermediate, true, false, false, false, false},
- {AggregationNode::Step::kIntermediate, true, true, true, false, false},
+ {AggregationNode::Step::kIntermediate, true, true, true, false, true},
{AggregationNode::Step::kIntermediate, true, true, false, true, false},
- {AggregationNode::Step::kIntermediate, true, true, false, false, false},
+ {AggregationNode::Step::kIntermediate, true, true, false, false, true},
{AggregationNode::Step::kPartial, false, true, false, false, false},
{AggregationNode::Step::kPartial, true, false, false, false, false},
- {AggregationNode::Step::kPartial, true, true, true, false, false},
+ {AggregationNode::Step::kPartial, true, true, true, false, true},
{AggregationNode::Step::kPartial, true, true, false, true, false},
- {AggregationNode::Step::kPartial, true, true, false, false, false},
+ {AggregationNode::Step::kPartial, true, true, false, false, true},
{AggregationNode::Step::kFinal, false, true, false, false, false},
{AggregationNode::Step::kFinal, true, false, false, false, false},
{AggregationNode::Step::kFinal, true, true, true, false, true},
8 changes: 6 additions & 2 deletions velox/exec/GroupingSet.cpp
@@ -830,7 +830,7 @@ const HashLookup& GroupingSet::hashLookup() const {
void GroupingSet::ensureInputFits(const RowVectorPtr& input) {
// Spilling is considered if this is a final or single aggregation and
// spillPath is set.
- if (isPartial_ || spillConfig_ == nullptr) {
+ if (spillConfig_ == nullptr) {
return;
}

@@ -913,7 +913,7 @@ void GroupingSet::ensureOutputFits() {
// to reserve memory for the output as we can't reclaim much memory from this
// operator itself. The output processing can reclaim memory from the other
// operator or query through memory arbitration.
- if (isPartial_ || spillConfig_ == nullptr || hasSpilled()) {
+ if (spillConfig_ == nullptr || hasSpilled()) {
return;
}

@@ -960,6 +960,10 @@ void GroupingSet::spill() {
return;
}

if (hasSpilled() && spiller_->finalized()) {
return;
}

if (!hasSpilled()) {
auto rows = table_->rows();
VELOX_DCHECK(pool_.trackUsage());
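Taken together, the GroupingSet changes remove the isPartial_ short-circuit from ensureInputFits() and ensureOutputFits(), so memory reservation and spilling now apply to partial aggregation too, and they add a guard so that spill() becomes a no-op once the spiller has been finalized. A rough sketch of the resulting control flow in spill(), reassembled from the hunk above (code outside the hunk is elided):

void GroupingSet::spill() {
  // ... early return when there is nothing to spill (not shown in the hunk) ...

  // New in this commit: once the spiller has been finalized there is nothing
  // left to spill, so bail out instead of touching it again.
  if (hasSpilled() && spiller_->finalized()) {
    return;
  }

  if (!hasSpilled()) {
    // First spill: build the spiller over the hash table's row container.
    auto rows = table_->rows();
    VELOX_DCHECK(pool_.trackUsage());
    // ... spiller construction continues past the end of the hunk ...
  }
}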
126 changes: 126 additions & 0 deletions velox/exec/tests/AggregationTest.cpp
@@ -20,6 +20,7 @@
#include "folly/experimental/EventCount.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/common/file/FileSystems.h"
#include "velox/common/memory/Memory.h"
#include "velox/common/testutil/TestValue.h"
#include "velox/dwio/common/tests/utils/BatchMaker.h"
#include "velox/exec/Aggregate.h"
@@ -397,6 +398,33 @@ class AggregationTest : public OperatorTestBase {
VARCHAR()})};
folly::Random::DefaultGenerator rng_;
memory::MemoryReclaimer::Stats reclaimerStats_;

std::shared_ptr<core::QueryCtx> newQueryCtx(
int64_t memoryCapacity = memory::kMaxMemory) {
std::unordered_map<std::string, std::shared_ptr<Config>> configs;
std::shared_ptr<memory::MemoryPool> pool = memoryManager_->addRootPool(
"", memoryCapacity, MemoryReclaimer::create());
auto queryCtx = std::make_shared<core::QueryCtx>(
executor_.get(),
core::QueryConfig({}),
configs,
cache::AsyncDataCache::getInstance(),
std::move(pool));
return queryCtx;
}

void setupMemory() {
memory::MemoryManagerOptions options;
options.arbitratorKind = "SHARED";
options.checkUsageLeak = true;
memoryAllocator_ = memory::MemoryAllocator::createDefaultInstance();
options.allocator = memoryAllocator_.get();
memoryManager_ = std::make_unique<memory::MemoryManager>(options);
}

private:
std::shared_ptr<memory::MemoryAllocator> memoryAllocator_;
std::unique_ptr<memory::MemoryManager> memoryManager_;
};

template <>
@@ -847,6 +875,104 @@ TEST_F(AggregationTest, partialAggregationMemoryLimit) {
.customStats.count("flushRowCount"));
}

// TODO move to arbitrator test
TEST_F(AggregationTest, partialAggregationSpill) {
VectorFuzzer::Options fuzzerOpts;
fuzzerOpts.vectorSize = 128;
RowTypePtr rowType = ROW(
{{"c0", INTEGER()},
{"c1", INTEGER()},
{"c2", INTEGER()},
{"c3", INTEGER()},
{"c4", INTEGER()},
{"c5", INTEGER()},
{"c6", INTEGER()},
{"c7", INTEGER()},
{"c8", INTEGER()},
{"c9", INTEGER()},
{"c10", INTEGER()}});
VectorFuzzer fuzzer(std::move(fuzzerOpts), pool());

std::vector<RowVectorPtr> vectors;

const int32_t numVectors = 2000;
for (int i = 0; i < numVectors; i++) {
vectors.push_back(fuzzer.fuzzRow(rowType));
}

createDuckDbTable(vectors);

setupMemory();

core::PlanNodeId partialAggNodeId;
core::PlanNodeId finalAggNodeId;
// Set an artificially low limit on the amount of data to accumulate in
// the partial aggregation.

// Distinct aggregation.
auto spillDirectory1 = exec::test::TempDirectoryPath::create();
auto task = AssertQueryBuilder(duckDbQueryRunner_)
.queryCtx(newQueryCtx(10LL << 10 << 10))
.spillDirectory(spillDirectory1->path)
.config(QueryConfig::kSpillEnabled, "true")
.config(QueryConfig::kAggregationSpillEnabled, "true")
.config(
QueryConfig::kAggregationSpillMemoryThreshold,
std::to_string(0)) // always spill on final agg
.plan(PlanBuilder()
.values(vectors)
.partialAggregation({"c0"}, {})
.capturePlanNodeId(partialAggNodeId)
.finalAggregation()
.capturePlanNodeId(finalAggNodeId)
.planNode())
.assertResults("SELECT distinct c0 FROM tmp");

checkSpillStats(toPlanStats(task->taskStats()).at(partialAggNodeId), true);
checkSpillStats(toPlanStats(task->taskStats()).at(finalAggNodeId), true);

// Count aggregation.
auto spillDirectory2 = exec::test::TempDirectoryPath::create();
task = AssertQueryBuilder(duckDbQueryRunner_)
.queryCtx(newQueryCtx(10LL << 10 << 10))
.spillDirectory(spillDirectory2->path)
.config(QueryConfig::kSpillEnabled, "true")
.config(QueryConfig::kAggregationSpillEnabled, "true")
.config(
QueryConfig::kAggregationSpillMemoryThreshold,
std::to_string(0)) // always spill on final agg
.plan(PlanBuilder()
.values(vectors)
.partialAggregation({"c0"}, {"count(1)"})
.capturePlanNodeId(partialAggNodeId)
.finalAggregation()
.capturePlanNodeId(finalAggNodeId)
.planNode())
.assertResults("SELECT c0, count(1) FROM tmp GROUP BY 1");

checkSpillStats(toPlanStats(task->taskStats()).at(partialAggNodeId), true);
checkSpillStats(toPlanStats(task->taskStats()).at(finalAggNodeId), true);

// Global aggregation.
task = AssertQueryBuilder(duckDbQueryRunner_)
.queryCtx(newQueryCtx(10LL << 10 << 10))
.plan(PlanBuilder()
.values(vectors)
.partialAggregation({}, {"sum(c0)"})
.capturePlanNodeId(partialAggNodeId)
.finalAggregation()
.capturePlanNodeId(finalAggNodeId)
.planNode())
.assertResults("SELECT sum(c0) FROM tmp");
EXPECT_EQ(
0,
toPlanStats(task->taskStats())
.at(partialAggNodeId)
.customStats.count("flushRowCount"));
checkSpillStats(toPlanStats(task->taskStats()).at(partialAggNodeId), false);
checkSpillStats(toPlanStats(task->taskStats()).at(finalAggNodeId), false);
}

TEST_F(AggregationTest, partialDistinctWithAbandon) {
auto vectors = {
// 1st batch will produce 100 distinct groups from 10 rows.
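Beyond exercising the new code path, the partialAggregationSpill test above doubles as a usage example: spilling in the partial stage is driven by the same knobs as the final stage. A trimmed-down sketch of the configuration it relies on (all identifiers come from the test; 10LL << 10 << 10 is a 10 MB cap on the query's root memory pool, arbitrated by the SHARED arbitrator set up in setupMemory()):

// Minimal sketch of how the test drives spilling in both aggregation stages.
auto spillDirectory = exec::test::TempDirectoryPath::create();
auto task =
    AssertQueryBuilder(duckDbQueryRunner_)
        .queryCtx(newQueryCtx(10LL << 10 << 10)) // low capacity forces spilling
        .spillDirectory(spillDirectory->path)
        .config(QueryConfig::kSpillEnabled, "true")
        .config(QueryConfig::kAggregationSpillEnabled, "true")
        .config(QueryConfig::kAggregationSpillMemoryThreshold, "0") // spill eagerly
        .plan(PlanBuilder()
                  .values(vectors)
                  .partialAggregation({"c0"}, {"count(1)"})
                  .finalAggregation()
                  .planNode())
        .assertResults("SELECT c0, count(1) FROM tmp GROUP BY 1");

The test then uses checkSpillStats to assert that both the partial and the final aggregation nodes report spilled data, while the global-aggregation variant (no grouping keys) is expected not to spill at all.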
33 changes: 23 additions & 10 deletions velox/functions/prestosql/aggregates/ApproxPercentileAggregate.cpp
@@ -639,6 +639,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
DecodedVector decodedDigest_;

private:
bool isConstantVector(const VectorPtr& vec) {
if (vec->isConstantEncoding()) {
return true;
}
VELOX_USER_CHECK(vec->size() > 0);
for (vector_size_t i = 1; i < vec->size(); ++i) {
if (!vec->equalValueAt(vec.get(), i, 0)) {
return false;
}
}
return true;
}

template <bool kSingleGroup, bool checkIntermediateInputs>
void addIntermediateImpl(
std::conditional_t<kSingleGroup, char*, char**> group,
@@ -650,7 +663,8 @@ class ApproxPercentileAggregate : public exec::Aggregate {
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(rowVec);
for (int i = kPercentiles; i <= kAccuracy; ++i) {
- VELOX_USER_CHECK(rowVec->childAt(i)->isConstantEncoding());
+ VELOX_USER_CHECK(isConstantVector(
+     rowVec->childAt(i))); // spilling flattens constant encoding
}
for (int i = kK; i <= kMaxValue; ++i) {
VELOX_USER_CHECK(rowVec->childAt(i)->isFlatEncoding());
Expand All @@ -677,10 +691,9 @@ class ApproxPercentileAggregate : public exec::Aggregate {
}

DecodedVector percentiles(*rowVec->childAt(kPercentiles), *baseRows);
- auto percentileIsArray =
-     rowVec->childAt(kPercentilesIsArray)->asUnchecked<SimpleVector<bool>>();
- auto accuracy =
-     rowVec->childAt(kAccuracy)->asUnchecked<SimpleVector<double>>();
+ DecodedVector percentileIsArray(
+     *rowVec->childAt(kPercentilesIsArray), *baseRows);
+ DecodedVector accuracy(*rowVec->childAt(kAccuracy), *baseRows);
auto k = rowVec->childAt(kK)->asUnchecked<SimpleVector<int32_t>>();
auto n = rowVec->childAt(kN)->asUnchecked<SimpleVector<int64_t>>();
auto minValue = rowVec->childAt(kMinValue)->asUnchecked<SimpleVector<T>>();
@@ -710,7 +723,7 @@ class ApproxPercentileAggregate : public exec::Aggregate {
return;
}
int i = decoded.index(row);
- if (percentileIsArray->isNullAt(i)) {
+ if (percentileIsArray.isNullAt(i)) {
return;
}
if (!accumulator) {
@@ -720,19 +733,19 @@ class ApproxPercentileAggregate : public exec::Aggregate {
percentilesBase->elements()->asFlatVector<double>();
if constexpr (checkIntermediateInputs) {
VELOX_USER_CHECK(percentileBaseElements);
- VELOX_USER_CHECK(!percentilesBase->isNullAt(indexInBaseVector));
+ VELOX_USER_CHECK(!percentiles.isNullAt(indexInBaseVector));
}

- bool isArray = percentileIsArray->valueAt(i);
+ bool isArray = percentileIsArray.valueAt<bool>(i);
const double* data;
vector_size_t len;
std::vector<bool> isNull;
extractPercentiles(
percentilesBase, indexInBaseVector, data, len, isNull);
checkSetPercentile(isArray, data, len, isNull);

- if (!accuracy->isNullAt(i)) {
-   checkSetAccuracy(accuracy->valueAt(i));
+ if (!accuracy.isNullAt(i)) {
+   checkSetAccuracy(accuracy.valueAt<double>(i));
}
}
if constexpr (kSingleGroup) {
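The ApproxPercentileAggregate adjustments follow from the fact that spilled intermediate rows are written out and read back, which loses constant encoding: a child that used to arrive as a constant vector may come back flat, with every row holding the same value (as the inline comment notes). Hence the relaxed isConstantVector() check and the switch from the encoding-specific asUnchecked<SimpleVector<...>> accessors to DecodedVector. A condensed sketch of the new access pattern, pulled from the hunks above:

// Decode instead of assuming a particular encoding; this works whether the
// child arrives constant (in-memory path) or flat (after a spill round trip).
DecodedVector percentileIsArray(
    *rowVec->childAt(kPercentilesIsArray), *baseRows);
DecodedVector accuracy(*rowVec->childAt(kAccuracy), *baseRows);
// ... later, per input row i ...
if (percentileIsArray.isNullAt(i)) {
  return; // skip this row, as in the code above
}
bool isArray = percentileIsArray.valueAt<bool>(i);
if (!accuracy.isNullAt(i)) {
  checkSetAccuracy(accuracy.valueAt<double>(i));
}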
