diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 9e3c1fc0fad3..a00da95cc6d7 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -1142,6 +1142,8 @@ bool MergeJoinNode::isSupported(core::JoinType joinType) { case core::JoinType::kRightSemiFilter: case core::JoinType::kAnti: case core::JoinType::kFull: + case core::JoinType::kLeftSemiProject: + case core::JoinType::kRightSemiProject: return true; default: diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 685eb1d5cf07..1899d5b9febd 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -65,7 +65,8 @@ void MergeJoin::initialize() { } } - if (joinNode_->isRightSemiFilterJoin()) { + if (joinNode_->isRightSemiFilterJoin() || + joinNode_->isRightSemiProjectJoin()) { VELOX_USER_CHECK( leftProjections_.empty(), "The left side projections should be empty for right semi join"); @@ -79,7 +80,7 @@ void MergeJoin::initialize() { } } - if (joinNode_->isLeftSemiFilterJoin()) { + if (joinNode_->isLeftSemiFilterJoin() || joinNode_->isLeftSemiProjectJoin()) { VELOX_USER_CHECK( rightProjections_.empty(), "The right side projections should be empty for left semi join"); @@ -89,10 +90,14 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin() || joinNode_->isFullJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin() || + joinNode_->isLeftSemiProjectJoin() || + joinNode_->isRightSemiProjectJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } - } else if (joinNode_->isAntiJoin()) { + } else if ( + joinNode_->isAntiJoin() || joinNode_->isLeftSemiProjectJoin() || + joinNode_->isRightSemiProjectJoin()) { // Anti join needs to track the left side rows that have no match on the // right. joinTracker_ = JoinTracker(outputBatchSize_, pool()); @@ -286,7 +291,8 @@ void MergeJoin::addOutputRowForLeftJoin( const RowVectorPtr& left, vector_size_t leftIndex) { VELOX_USER_CHECK( - isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_)); + isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)); rawLeftIndices_[outputSize_] = leftIndex; for (const auto& projection : rightProjections_) { @@ -305,7 +311,9 @@ void MergeJoin::addOutputRowForLeftJoin( void MergeJoin::addOutputRowForRightJoin( const RowVectorPtr& right, vector_size_t rightIndex) { - VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_)); + VELOX_USER_CHECK( + isRightJoin(joinType_) || isFullJoin(joinType_) || + isRightSemiProjectJoin(joinType_)); rawRightIndices_[outputSize_] = rightIndex; for (const auto& projection : leftProjections_) { @@ -358,7 +366,7 @@ void MergeJoin::addOutputRow( copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); if (joinTracker_) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { // Record right-side row with a match on the left-side. joinTracker_->addMatch(right, rightIndex, outputSize_); } else { @@ -370,12 +378,18 @@ void MergeJoin::addOutputRow( // Anti join needs to track the left side rows that have no match on the // right. - if (isAntiJoin(joinType_)) { + if (isAntiJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) { VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. joinTracker_->addMatch(left, leftIndex, outputSize_); } + if (isRightSemiProjectJoin(joinType_)) { + VELOX_CHECK(joinTracker_); + // Record right-side row with a match on the left-side. + joinTracker_->addMatch(right, rightIndex, outputSize_); + } + ++outputSize_; } @@ -449,6 +463,12 @@ bool MergeJoin::prepareOutput( isRightFlattened_ = false; } currentRight_ = right; + if (isRightSemiProjectJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) { + localColumns[outputType_->size() - 1] = BaseVector::create( + outputType_->childAt(outputType_->size() - 1), + outputBatchSize_, + operatorCtx_->pool()); + } output_ = std::make_shared( operatorCtx_->pool(), @@ -456,6 +476,7 @@ bool MergeJoin::prepareOutput( nullptr, outputBatchSize_, std::move(localColumns)); + outputSize_ = 0; if (filterInput_ != nullptr) { @@ -507,7 +528,8 @@ bool MergeJoin::prepareOutput( } bool MergeJoin::addToOutput() { - if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { return addToOutputForRightJoin(); } else { return addToOutputForLeftJoin(); @@ -560,7 +582,8 @@ bool MergeJoin::addToOutputForLeftJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isLeftSemiFilterJoin(joinType_)) { + if (isLeftSemiFilterJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)) { // LeftSemiFilter produce each row from the left at most once. rightEnd = rightStart + 1; } @@ -638,7 +661,8 @@ bool MergeJoin::addToOutputForRightJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isRightSemiFilterJoin(joinType_)) { + if (isRightSemiFilterJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { // RightSemiFilter produce each row from the right at most once. leftEnd = leftStart + 1; } @@ -719,6 +743,26 @@ RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) { return wrap(numPassed, indices, output); } +RowVectorPtr MergeJoin::filterOutputForSemiProject(const RowVectorPtr& output) { + auto numRows = output->size(); + const auto& filterRows = joinTracker_->matchingRows(numRows); + + auto lastChildren = output->children().back(); + auto flatMatch = lastChildren->as>(); + flatMatch->resize(numRows); + auto rawValues = flatMatch->mutableRawValues(); + + for (auto i = 0; i < numRows; i++) { + if (filterRows.isValid(i)) { + bits::setBit(rawValues, i, true); + } else { + bits::setBit(rawValues, i, false); + } + } + + return output; +} + RowVectorPtr MergeJoin::getOutput() { // Make sure to have is-blocked or needs-input as true if returning null // output. Otherwise, Driver assumes the operator is finished. @@ -730,6 +774,7 @@ RowVectorPtr MergeJoin::getOutput() { for (;;) { auto output = doGetOutput(); + if (output != nullptr && output->size() > 0) { if (filter_) { output = applyFilter(output); @@ -751,6 +796,16 @@ RowVectorPtr MergeJoin::getOutput() { // No rows survived the filter for anti join. Get more rows. continue; + } else if ( + isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + output = filterOutputForSemiProject(output); + if (output) { + return output; + } + + // No rows survived the filter. Get more rows. + continue; } else { return output; } @@ -871,7 +926,8 @@ RowVectorPtr MergeJoin::doGetOutput() { } if (!input_ || !rightInput_) { - if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) { + if (isLeftJoin(joinType_) || isAntiJoin(joinType_) || + isLeftSemiProjectJoin(joinType_)) { if (input_ && noMoreRightInput_) { // If output_ is currently wrapping a different buffer, return it // first. @@ -898,7 +954,7 @@ RowVectorPtr MergeJoin::doGetOutput() { output_->resize(outputSize_); return std::move(output_); } - } else if (isRightJoin(joinType_)) { + } else if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { if (rightInput_ && noMoreInput_) { // If output_ is currently wrapping a different buffer, return it // first. @@ -1004,7 +1060,7 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up input_ with rightInput_. while (compareResult < 0) { if (isLeftJoin(joinType_) || isAntiJoin(joinType_) || - isFullJoin(joinType_)) { + isFullJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(input_, nullptr)) { @@ -1031,7 +1087,8 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up rightInput_ with input_. while (compareResult > 0) { - if (isRightJoin(joinType_) || isFullJoin(joinType_)) { + if (isRightJoin(joinType_) || isFullJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. if (prepareOutput(nullptr, rightInput_)) { @@ -1139,17 +1196,36 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!filterRows.hasSelections()) { // No matches in the output, no need to evaluate the filter. - return output; + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + return filterOutputForSemiProject(output); + } else { + return output; + } } evaluateFilter(filterRows); + FlatVector* flatMatch{nullptr}; + uint64_t* rawValues; + + if (isLeftSemiProjectJoin(joinType_) || isRightSemiProjectJoin(joinType_)) { + flatMatch = output->children().back()->as>(); + flatMatch->resize(numRows); + rawValues = flatMatch->mutableRawValues(); + } + // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. auto onMiss = [&](auto row) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, row, false); + } + if (isFullJoin(joinType_)) { // For filtered rows, it is necessary to insert additional data // to ensure the result set is complete. Specifically, we @@ -1204,7 +1280,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { auto target = fullOuterOutput->childAt(projection.outputChannel); target->setNull(row + 1, true); } - } else if (!isRightJoin(joinType_)) { + } else if ( + !isRightJoin(joinType_) && !isRightSemiProjectJoin(joinType_)) { for (auto& projection : rightProjections_) { auto target = output->childAt(projection.outputChannel); target->setNull(row, true); @@ -1231,6 +1308,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } } else { if (passed) { + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, i, true); + } rawIndices[numPassed++] = i; } } @@ -1238,6 +1319,11 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // This row doesn't have a match on the right side. Keep it // unconditionally. rawIndices[numPassed++] = i; + + if (isLeftSemiProjectJoin(joinType_) || + isRightSemiProjectJoin(joinType_)) { + bits::setBit(rawValues, i, false); + } } } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index 3530316b90c7..c731be05124a 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -247,6 +247,10 @@ class MergeJoin : public Operator { /// rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); + /// Return each row from the left or right side with a boolean flag indicating + /// whether there exists a match on the right or left side. + RowVectorPtr filterOutputForSemiProject(const RowVectorPtr& output); + /// As we populate the results of the join, we track whether a given /// output row is a result of a match between left and right sides or a miss. /// We use JoinTracker::addMatch and addMiss methods for that. diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 72062c641679..72f7264ce57a 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -756,6 +756,60 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, semiJoinProjection) { + auto left = makeRowVector( + {"t0"}, {makeNullableFlatVector({1, 2, 2, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 2, 2, 7, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql); + }; + + testSemiJoin( + "", + "SELECT t.t0, EXISTS (SELECT * FROM u WHERE t.t0 = u.u0) FROM t", + {"t0", "match"}, + core::JoinType::kLeftSemiProject); + testSemiJoin( + "", + "SELECT u0, u0 IN (SELECT * FROM t where t.t0 = u.u0) FROM u", + {"u0", "match"}, + core::JoinType::kRightSemiProject); + + testSemiJoin( + "t0 > 1", + "SELECT t.t0, EXISTS (SELECT * FROM u WHERE t0 = u0 and t.t0 > 1) FROM t", + {"t0", "match"}, + core::JoinType::kLeftSemiProject); + testSemiJoin( + "u0 > 1", + "SELECT u0, u0 IN (SELECT * FROM t where t0 = u0 and u0 > 1) FROM u", + {"u0", "match"}, + core::JoinType::kRightSemiProject); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"}, diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index 31c3a5d39cba..e54848543768 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -1423,7 +1423,23 @@ PlanBuilder& PlanBuilder::mergeJoin( if (!filter.empty()) { filterExpr = parseExpr(filter, resultType, options_, pool_); } - auto outputType = extract(resultType, outputLayout); + RowTypePtr outputType; + if (isLeftSemiProjectJoin(joinType) || isRightSemiProjectJoin(joinType)) { + std::vector names = outputLayout; + + // Last column in 'outputLayout' must be a boolean 'match'. + std::vector types; + types.reserve(outputLayout.size()); + for (auto i = 0; i < outputLayout.size() - 1; ++i) { + types.emplace_back(resultType->findChild(outputLayout[i])); + } + types.emplace_back(BOOLEAN()); + + outputType = ROW(std::move(names), std::move(types)); + } else { + outputType = extract(resultType, outputLayout); + } + auto leftKeyFields = fields(leftType, leftKeys); auto rightKeyFields = fields(rightType, rightKeys);