diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 685eb1d5cf075..9b294f681feab 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -531,23 +531,28 @@ bool MergeJoin::addToOutputForLeftJoin() { auto leftStart = l == firstLeftBatch ? leftStartIndex : 0; auto leftEnd = l == numLefts - 1 ? leftMatch_->endIndex : left->size(); + auto rightEnd = 0; + auto rightStart = 0; + auto firstRightBatch = 0; + auto rightStartIndex = 0; + auto numRights = 0; for (auto i = leftStart; i < leftEnd; ++i) { - auto firstRightBatch = + firstRightBatch = (l == firstLeftBatch && i == leftStart && rightMatch_->cursor) ? rightMatch_->cursor->batchIndex : 0; - auto rightStartIndex = + rightStartIndex = (l == firstLeftBatch && i == leftStart && rightMatch_->cursor) ? rightMatch_->cursor->index : rightMatch_->startIndex; - auto numRights = rightMatch_->inputs.size(); + numRights = rightMatch_->inputs.size(); + for (size_t r = firstRightBatch; r < numRights; ++r) { auto right = rightMatch_->inputs[r]; - auto rightStart = r == firstRightBatch ? rightStartIndex : 0; - auto rightEnd = - r == numRights - 1 ? rightMatch_->endIndex : right->size(); + rightStart = r == firstRightBatch ? rightStartIndex : 0; + rightEnd = r == numRights - 1 ? rightMatch_->endIndex : right->size(); if (prepareOutput(left, right)) { output_->resize(outputSize_); @@ -576,10 +581,56 @@ bool MergeJoin::addToOutputForLeftJoin() { rightMatch_->setCursor(r, j); return true; } + addOutputRow(left, i, right, j); } } } + + // Add a null value to the left side when there is no matching row on the + // right side after applying the filter. + if (isFullJoin(joinType_) && filter_) { + auto numRows = (leftEnd - leftStart) * (rightEnd - rightStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange((outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + for (size_t r = firstRightBatch; r < numRights; ++r) { + auto right = rightMatch_->inputs[r]; + for (auto i = rightStart; i < rightEnd; ++i) { + bool rightMatched = false; + for (auto j = leftStart; j < leftEnd; ++j) { + auto rowIndex = processedRowNums + + (j - leftStart) * (rightEnd - rightStart) + i - rightStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + rightMatched = passed; + } + } + + if (!rightMatched) { + if (!isRightFlattened_) { + rawRightIndices_[outputSize_] = i; + } else { + copyRow(right, i, output_, outputSize_, rightProjections_); + } + + for (const auto& projection : leftProjections_) { + const auto& target = output_->childAt(projection.outputChannel); + target->setNull(outputSize_, true); + } + + joinTracker_->addMiss(outputSize_); + + ++outputSize_; + } + } + } + } } leftMatch_.reset(); @@ -1128,8 +1179,6 @@ RowVectorPtr MergeJoin::doGetOutput() { RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); - RowVectorPtr fullOuterOutput = nullptr; - BufferPtr indices = allocateIndices(numRows, pool()); auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; @@ -1150,61 +1199,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; - if (isFullJoin(joinType_)) { - // For filtered rows, it is necessary to insert additional data - // to ensure the result set is complete. Specifically, we - // need to generate two records: one record containing the - // columns from the left table along with nulls for the - // right table, and another record containing the columns - // from the right table along with nulls for the left table. - // For instance, the current output is filtered based on the condition - // t > 1. - - // 1, 1 - // 2, 2 - // 3, 3 - - // In this scenario, we need to additionally insert a record 1, 1. - // Subsequently, we will set the values of the columns on the left to - // null and the values of the columns on the right to null as well. By - // doing so, we will obtain the final result set. - - // 1, null - // null, 1 - // 2, 2 - // 3, 3 - fullOuterOutput = BaseVector::create( - output->type(), output->size() + 1, pool()); - - for (auto i = 0; i < row + 1; i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i, i, 1); - } - } - - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), row + 1, row, 1); - } - - for (auto i = row + 1; i < output->size(); i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i + 1, i, 1); - } - } - - for (auto& projection : leftProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row, true); - } - - for (auto& projection : rightProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row + 1, true); - } - } else if (!isRightJoin(joinType_)) { + if (!isRightJoin(joinType_)) { for (auto& projection : rightProjections_) { auto target = output->childAt(projection.outputChannel); target->setNull(row, true); @@ -1284,17 +1279,9 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (numPassed == numRows) { // All rows passed. - if (fullOuterOutput) { - return fullOuterOutput; - } return output; } - // Some, but not all rows passed. - if (fullOuterOutput) { - return wrap(numPassed, indices, fullOuterOutput); - } - return wrap(numPassed, indices, output); } diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 72062c6416790..e8182bdc04f50 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -508,6 +508,45 @@ TEST_F(MergeJoinTest, leftAndRightJoinFilter) { } } +TEST_F(MergeJoinTest, fullOuterJoinWithDuplicateMatch) { + // Each row on the left side has at most one match on the right side. + auto left = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector({1, 2, 2, 2, 3, 5, 6, std::nullopt}), + makeNullableFlatVector( + {2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}), + }); + + auto right = makeRowVector( + {"c", "d"}, + { + makeNullableFlatVector( + {0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}), + makeNullableFlatVector( + {0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}), + }); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"a"}, + {"c"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "b < d", + {"a", "b", "c", "d"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(rightPlan, duckDbQueryRunner_) + .assertResults("SELECT * from t FULL OUTER JOIN u ON a = c AND b < d"); +} + TEST_F(MergeJoinTest, rightJoinWithDuplicateMatch) { // Each row on the left side has at most one match on the right side. auto left = makeRowVector(