Skip to content

Commit

Permalink
Fix full outer join result mismatch issue
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Sep 25, 2024
1 parent 7483a76 commit 30158db
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 71 deletions.
129 changes: 58 additions & 71 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,23 +531,28 @@ bool MergeJoin::addToOutputForLeftJoin() {
auto leftStart = l == firstLeftBatch ? leftStartIndex : 0;
auto leftEnd = l == numLefts - 1 ? leftMatch_->endIndex : left->size();

auto rightEnd = 0;
auto rightStart = 0;
auto firstRightBatch = 0;
auto rightStartIndex = 0;
auto numRights = 0;
for (auto i = leftStart; i < leftEnd; ++i) {
auto firstRightBatch =
firstRightBatch =
(l == firstLeftBatch && i == leftStart && rightMatch_->cursor)
? rightMatch_->cursor->batchIndex
: 0;

auto rightStartIndex =
rightStartIndex =
(l == firstLeftBatch && i == leftStart && rightMatch_->cursor)
? rightMatch_->cursor->index
: rightMatch_->startIndex;

auto numRights = rightMatch_->inputs.size();
numRights = rightMatch_->inputs.size();

for (size_t r = firstRightBatch; r < numRights; ++r) {
auto right = rightMatch_->inputs[r];
auto rightStart = r == firstRightBatch ? rightStartIndex : 0;
auto rightEnd =
r == numRights - 1 ? rightMatch_->endIndex : right->size();
rightStart = r == firstRightBatch ? rightStartIndex : 0;
rightEnd = r == numRights - 1 ? rightMatch_->endIndex : right->size();

if (prepareOutput(left, right)) {
output_->resize(outputSize_);
Expand Down Expand Up @@ -576,10 +581,56 @@ bool MergeJoin::addToOutputForLeftJoin() {
rightMatch_->setCursor(r, j);
return true;
}

addOutputRow(left, i, right, j);
}
}
}

// Add a null value to the left side when there is no matching row on the
// right side after applying the filter.
if (isFullJoin(joinType_) && filter_) {
auto numRows = (leftEnd - leftStart) * (rightEnd - rightStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange((outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);
for (size_t r = firstRightBatch; r < numRights; ++r) {
auto right = rightMatch_->inputs[r];
for (auto i = rightStart; i < rightEnd; ++i) {
bool rightMatched = false;
for (auto j = leftStart; j < leftEnd; ++j) {
auto rowIndex = processedRowNums +
(j - leftStart) * (rightEnd - rightStart) + i - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
rightMatched = passed;
}
}

if (!rightMatched) {
if (!isRightFlattened_) {
rawRightIndices_[outputSize_] = i;
} else {
copyRow(right, i, output_, outputSize_, rightProjections_);
}

for (const auto& projection : leftProjections_) {
const auto& target = output_->childAt(projection.outputChannel);
target->setNull(outputSize_, true);
}

joinTracker_->addMiss(outputSize_);

++outputSize_;
}
}
}
}
}

leftMatch_.reset();
Expand Down Expand Up @@ -1128,8 +1179,6 @@ RowVectorPtr MergeJoin::doGetOutput() {
RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
const auto numRows = output->size();

RowVectorPtr fullOuterOutput = nullptr;

BufferPtr indices = allocateIndices(numRows, pool());
auto rawIndices = indices->asMutable<vector_size_t>();
vector_size_t numPassed = 0;
Expand All @@ -1150,61 +1199,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
// For filtered rows, it is necessary to insert additional data
// to ensure the result set is complete. Specifically, we
// need to generate two records: one record containing the
// columns from the left table along with nulls for the
// right table, and another record containing the columns
// from the right table along with nulls for the left table.
// For instance, the current output is filtered based on the condition
// t > 1.

// 1, 1
// 2, 2
// 3, 3

// In this scenario, we need to additionally insert a record 1, 1.
// Subsequently, we will set the values of the columns on the left to
// null and the values of the columns on the right to null as well. By
// doing so, we will obtain the final result set.

// 1, null
// null, 1
// 2, 2
// 3, 3
fullOuterOutput = BaseVector::create<RowVector>(
output->type(), output->size() + 1, pool());

for (auto i = 0; i < row + 1; i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i, i, 1);
}
}

for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), row + 1, row, 1);
}

for (auto i = row + 1; i < output->size(); i++) {
for (auto j = 0; j < output->type()->size(); j++) {
fullOuterOutput->childAt(j)->copy(
output->childAt(j).get(), i + 1, i, 1);
}
}

for (auto& projection : leftProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row, true);
}

for (auto& projection : rightProjections_) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
if (!isRightJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
Expand Down Expand Up @@ -1284,17 +1279,9 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (numPassed == numRows) {
// All rows passed.
if (fullOuterOutput) {
return fullOuterOutput;
}
return output;
}

// Some, but not all rows passed.
if (fullOuterOutput) {
return wrap(numPassed, indices, fullOuterOutput);
}

return wrap(numPassed, indices, output);
}

Expand Down
39 changes: 39 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,45 @@ TEST_F(MergeJoinTest, leftAndRightJoinFilter) {
}
}

TEST_F(MergeJoinTest, fullOuterJoinWithDuplicateMatch) {
// Each row on the left side has at most one match on the right side.
auto left = makeRowVector(
{"a", "b"},
{
makeNullableFlatVector<int32_t>({1, 2, 2, 2, 3, 5, 6, std::nullopt}),
makeNullableFlatVector<double>(
{2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}),
});

auto right = makeRowVector(
{"c", "d"},
{
makeNullableFlatVector<int32_t>(
{0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}),
makeNullableFlatVector<double>(
{0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}),
});

createDuckDbTable("t", {left});
createDuckDbTable("u", {right});

auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();

auto rightPlan =
PlanBuilder(planNodeIdGenerator)
.values({left})
.mergeJoin(
{"a"},
{"c"},
PlanBuilder(planNodeIdGenerator).values({right}).planNode(),
"b < d",
{"a", "b", "c", "d"},
core::JoinType::kFull)
.planNode();
AssertQueryBuilder(rightPlan, duckDbQueryRunner_)
.assertResults("SELECT * from t FULL OUTER JOIN u ON a = c AND b < d");
}

TEST_F(MergeJoinTest, rightJoinWithDuplicateMatch) {
// Each row on the left side has at most one match on the right side.
auto left = makeRowVector(
Expand Down

0 comments on commit 30158db

Please sign in to comment.