diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 685eb1d5cf075..54d38c96aa856 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/exec/MergeJoin.h" +#include #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -89,7 +90,9 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin() || joinNode_->isFullJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin() || + joinNode_->isLeftSemiFilterJoin() || + joinNode_->isRightSemiFilterJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { @@ -560,7 +563,7 @@ bool MergeJoin::addToOutputForLeftJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isLeftSemiFilterJoin(joinType_)) { + if (isLeftSemiFilterJoin(joinType_) && !filter_) { // LeftSemiFilter produce each row from the left at most once. rightEnd = rightStart + 1; } @@ -578,6 +581,32 @@ bool MergeJoin::addToOutputForLeftJoin() { } addOutputRow(left, i, right, j); } + + if (isLeftSemiFilterJoin(joinType_) && filter_) { + auto numRows = (rightEnd - rightStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange( + (outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + + auto flag = true; + for (auto j = rightStart; j < rightEnd; ++j) { + auto rowIndex = processedRowNums + j - rightStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + if (flag) { + flag = false; + } else { + joinTracker_->addMiss(rowIndex); + } + } + } + } } } } @@ -638,7 +667,7 @@ bool MergeJoin::addToOutputForRightJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isRightSemiFilterJoin(joinType_)) { + if (isRightSemiFilterJoin(joinType_) && !filter_) { // RightSemiFilter produce each row from the right at most once. leftEnd = leftStart + 1; } @@ -656,6 +685,32 @@ bool MergeJoin::addToOutputForRightJoin() { } addOutputRow(left, j, right, i); } + + if (isRightSemiFilterJoin(joinType_) && filter_) { + auto numRows = (leftEnd - leftStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange( + (outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + + auto flag = true; + for (auto j = leftStart; j < leftEnd; ++j) { + auto rowIndex = processedRowNums + j - leftStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + if (flag) { + flag = false; + } else { + joinTracker_->addMiss(rowIndex); + } + } + } + } } } } @@ -1147,7 +1202,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. auto onMiss = [&](auto row) { - if (!isAntiJoin(joinType_)) { + if (!isAntiJoin(joinType_) && !isRightSemiFilterJoin(joinType_) && + !isLeftSemiFilterJoin(joinType_)) { rawIndices[numPassed++] = row; if (isFullJoin(joinType_)) { @@ -1235,9 +1291,12 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } } } else { - // This row doesn't have a match on the right side. Keep it - // unconditionally. - rawIndices[numPassed++] = i; + if (!isLeftSemiFilterJoin(joinType_) && + !isRightSemiFilterJoin(joinType_)) { + // This row doesn't have a match on the right side. Keep it + // unconditionally. + rawIndices[numPassed++] = i; + } } } diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 72062c6416790..d4d3ffcbcd228 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -756,6 +756,51 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, leftSemiJoin) { + auto left = makeRowVector( + {"t0", "t1"}, + {makeNullableFlatVector({"val1b", "val1b", "val1b"}), + makeNullableFlatVector({12, 12, 12})}); + + auto right = makeRowVector( + {"u0", "u1"}, + {makeNullableFlatVector({"val1b", "val1b", "val1b"}), + makeNullableFlatVector({12, 16, 16})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql); + }; + + testSemiJoin( + "t1 < u1", + "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 < u1)", + {"t0", "t1"}, + core::JoinType::kLeftSemiFilter); + testSemiJoin( + "t1 < u1", + "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where t1 < u1)", + {"u0", "u1"}, + core::JoinType::kRightSemiFilter); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"},