From f57c265b31948ac982b54ba783fe96713ff77999 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Tue, 15 Oct 2024 18:02:46 +0800 Subject: [PATCH] fix semi join result mismatch issue --- velox/exec/MergeJoin.cpp | 90 +++++++++++++++++++++++++++--- velox/exec/tests/MergeJoinTest.cpp | 45 +++++++++++++++ 2 files changed, 128 insertions(+), 7 deletions(-) diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 685eb1d5cf075..a1a3cf5eac18b 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -89,7 +89,9 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin() || joinNode_->isFullJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin() || + joinNode_->isLeftSemiFilterJoin() || + joinNode_->isRightSemiFilterJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { @@ -543,6 +545,12 @@ bool MergeJoin::addToOutputForLeftJoin() { : rightMatch_->startIndex; auto numRights = rightMatch_->inputs.size(); + + if (isLeftSemiFilterJoin(joinType_) && !filter_) { + // LeftSemiFilter produce each row from the left at most once. + numRights = 1; + } + for (size_t r = firstRightBatch; r < numRights; ++r) { auto right = rightMatch_->inputs[r]; auto rightStart = r == firstRightBatch ? rightStartIndex : 0; @@ -560,7 +568,7 @@ bool MergeJoin::addToOutputForLeftJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isLeftSemiFilterJoin(joinType_)) { + if (isLeftSemiFilterJoin(joinType_) && !filter_) { // LeftSemiFilter produce each row from the left at most once. rightEnd = rightStart + 1; } @@ -578,6 +586,35 @@ bool MergeJoin::addToOutputForLeftJoin() { } addOutputRow(left, i, right, j); } + + if (isLeftSemiFilterJoin(joinType_) && filter_) { + auto numRows = (rightEnd - rightStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange( + (outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + + auto firstMatchedRow = false; + for (auto j = rightStart; j < rightEnd; ++j) { + auto rowIndex = processedRowNums + j - rightStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + if (!firstMatchedRow) { + firstMatchedRow = true; + } else { + joinTracker_->addMiss(rowIndex); + } + } + } + if (firstMatchedRow) { + break; + } + } } } } @@ -622,6 +659,12 @@ bool MergeJoin::addToOutputForRightJoin() { : leftMatch_->startIndex; auto numLefts = leftMatch_->inputs.size(); + + if (isRightSemiFilterJoin(joinType_) && !filter_) { + // RightSemiFilter produce each row from the left at most once. + numRights = 1; + } + for (size_t l = firstLeftBatch; l < numLefts; ++l) { auto left = leftMatch_->inputs[l]; auto leftStart = l == firstLeftBatch ? leftStartIndex : 0; @@ -638,7 +681,7 @@ bool MergeJoin::addToOutputForRightJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isRightSemiFilterJoin(joinType_)) { + if (isRightSemiFilterJoin(joinType_) && !filter_) { // RightSemiFilter produce each row from the right at most once. leftEnd = leftStart + 1; } @@ -656,6 +699,35 @@ bool MergeJoin::addToOutputForRightJoin() { } addOutputRow(left, j, right, i); } + + if (isRightSemiFilterJoin(joinType_) && filter_) { + auto numRows = (leftEnd - leftStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange( + (outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + + auto firstMatchedRow = false; + for (auto j = leftStart; j < leftEnd; ++j) { + auto rowIndex = processedRowNums + j - leftStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + if (!firstMatchedRow) { + firstMatchedRow = true; + } else { + joinTracker_->addMiss(rowIndex); + } + } + } + if (firstMatchedRow) { + break; + } + } } } } @@ -1147,7 +1219,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. auto onMiss = [&](auto row) { - if (!isAntiJoin(joinType_)) { + if (!isAntiJoin(joinType_) && !isRightSemiFilterJoin(joinType_) && + !isLeftSemiFilterJoin(joinType_)) { rawIndices[numPassed++] = row; if (isFullJoin(joinType_)) { @@ -1235,9 +1308,12 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } } } else { - // This row doesn't have a match on the right side. Keep it - // unconditionally. - rawIndices[numPassed++] = i; + if (!isLeftSemiFilterJoin(joinType_) && + !isRightSemiFilterJoin(joinType_)) { + // This row doesn't have a match on the right side. Keep it + // unconditionally. + rawIndices[numPassed++] = i; + } } } diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 72062c6416790..d4d3ffcbcd228 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -756,6 +756,51 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, leftSemiJoin) { + auto left = makeRowVector( + {"t0", "t1"}, + {makeNullableFlatVector({"val1b", "val1b", "val1b"}), + makeNullableFlatVector({12, 12, 12})}); + + auto right = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({"val1b", "val1b", "val1b"}), + makeNullableFlatVector({12, 16, 16})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql); + }; + + testSemiJoin( + "t1 < u1", + "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 < u1)", + {"t0", "t1"}, + core::JoinType::kLeftSemiFilter); + testSemiJoin( + "t1 < u1", + "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where t1 < u1)", + {"u0", "u1"}, + core::JoinType::kRightSemiFilter); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"},