Skip to content

Commit

Permalink
Support left and right semi project join in smj
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Oct 9, 2024
1 parent 63c848d commit b460dc5
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 18 deletions.
2 changes: 2 additions & 0 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,8 @@ bool MergeJoinNode::isSupported(core::JoinType joinType) {
case core::JoinType::kRightSemiFilter:
case core::JoinType::kAnti:
case core::JoinType::kFull:
case core::JoinType::kLeftSemiProject:
case core::JoinType::kRightSemiProject:
return true;

default:
Expand Down
120 changes: 103 additions & 17 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ void MergeJoin::initialize() {
}
}

if (joinNode_->isRightSemiFilterJoin()) {
if (joinNode_->isRightSemiFilterJoin() ||
joinNode_->isRightSemiProjectJoin()) {
VELOX_USER_CHECK(
leftProjections_.empty(),
"The left side projections should be empty for right semi join");
Expand All @@ -79,7 +80,7 @@ void MergeJoin::initialize() {
}
}

if (joinNode_->isLeftSemiFilterJoin()) {
if (joinNode_->isLeftSemiFilterJoin() || joinNode_->isLeftSemiProjectJoin()) {
VELOX_USER_CHECK(
rightProjections_.empty(),
"The right side projections should be empty for left semi join");
Expand All @@ -89,10 +90,14 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
joinNode_->isLeftSemiProjectJoin() ||
joinNode_->isRightSemiProjectJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
} else if (
joinNode_->isAntiJoin() || joinNode_->isLeftSemiProjectJoin() ||
joinNode_->isRightSemiProjectJoin()) {
// Anti join needs to track the left side rows that have no match on the
// right.
joinTracker_ = JoinTracker(outputBatchSize_, pool());
Expand Down Expand Up @@ -286,7 +291,8 @@ void MergeJoin::addOutputRowForLeftJoin(
const RowVectorPtr& left,
vector_size_t leftIndex) {
VELOX_USER_CHECK(
isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_));
isLeftJoin(joinType_) || isAntiJoin(joinType_) || isFullJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_));
rawLeftIndices_[outputSize_] = leftIndex;

for (const auto& projection : rightProjections_) {
Expand All @@ -305,7 +311,9 @@ void MergeJoin::addOutputRowForLeftJoin(
void MergeJoin::addOutputRowForRightJoin(
const RowVectorPtr& right,
vector_size_t rightIndex) {
VELOX_USER_CHECK(isRightJoin(joinType_) || isFullJoin(joinType_));
VELOX_USER_CHECK(
isRightJoin(joinType_) || isFullJoin(joinType_) ||
isRightSemiProjectJoin(joinType_));
rawRightIndices_[outputSize_] = rightIndex;

for (const auto& projection : leftProjections_) {
Expand Down Expand Up @@ -358,7 +366,7 @@ void MergeJoin::addOutputRow(
copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_);

if (joinTracker_) {
if (isRightJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
} else {
Expand All @@ -370,12 +378,18 @@ void MergeJoin::addOutputRow(

// Anti join needs to track the left side rows that have no match on the
// right.
if (isAntiJoin(joinType_)) {
if (isAntiJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
VELOX_CHECK(joinTracker_);
// Record left-side row with a match on the right-side.
joinTracker_->addMatch(left, leftIndex, outputSize_);
}

if (isRightSemiProjectJoin(joinType_)) {
VELOX_CHECK(joinTracker_);
// Record right-side row with a match on the left-side.
joinTracker_->addMatch(right, rightIndex, outputSize_);
}

++outputSize_;
}

Expand Down Expand Up @@ -449,13 +463,20 @@ bool MergeJoin::prepareOutput(
isRightFlattened_ = false;
}
currentRight_ = right;
if (isRightSemiProjectJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
localColumns[outputType_->size() - 1] = BaseVector::create(
outputType_->childAt(outputType_->size() - 1),
outputBatchSize_,
operatorCtx_->pool());
}

output_ = std::make_shared<RowVector>(
operatorCtx_->pool(),
outputType_,
nullptr,
outputBatchSize_,
std::move(localColumns));

outputSize_ = 0;

if (filterInput_ != nullptr) {
Expand Down Expand Up @@ -507,7 +528,8 @@ bool MergeJoin::prepareOutput(
}

bool MergeJoin::addToOutput() {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) {
if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
return addToOutputForRightJoin();
} else {
return addToOutputForLeftJoin();
Expand Down Expand Up @@ -560,7 +582,8 @@ bool MergeJoin::addToOutputForLeftJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_)) {
if (isLeftSemiFilterJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_)) {
// LeftSemiFilter produce each row from the left at most once.
rightEnd = rightStart + 1;
}
Expand Down Expand Up @@ -638,7 +661,8 @@ bool MergeJoin::addToOutputForRightJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isRightSemiFilterJoin(joinType_)) {
if (isRightSemiFilterJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
// RightSemiFilter produce each row from the right at most once.
leftEnd = leftStart + 1;
}
Expand Down Expand Up @@ -719,6 +743,26 @@ RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) {
return wrap(numPassed, indices, output);
}

RowVectorPtr MergeJoin::filterOutputForSemiProject(const RowVectorPtr& output) {
auto numRows = output->size();
const auto& filterRows = joinTracker_->matchingRows(numRows);

auto lastChildren = output->children().back();
auto flatMatch = lastChildren->as<FlatVector<bool>>();
flatMatch->resize(numRows);
auto rawValues = flatMatch->mutableRawValues<uint64_t>();

for (auto i = 0; i < numRows; i++) {
if (filterRows.isValid(i)) {
bits::setBit(rawValues, i, true);
} else {
bits::setBit(rawValues, i, false);
}
}

return output;
}

RowVectorPtr MergeJoin::getOutput() {
// Make sure to have is-blocked or needs-input as true if returning null
// output. Otherwise, Driver assumes the operator is finished.
Expand All @@ -730,6 +774,7 @@ RowVectorPtr MergeJoin::getOutput() {

for (;;) {
auto output = doGetOutput();

if (output != nullptr && output->size() > 0) {
if (filter_) {
output = applyFilter(output);
Expand All @@ -751,6 +796,16 @@ RowVectorPtr MergeJoin::getOutput() {

// No rows survived the filter for anti join. Get more rows.
continue;
} else if (
isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
output = filterOutputForSemiProject(output);
if (output) {
return output;
}

// No rows survived the filter. Get more rows.
continue;
} else {
return output;
}
Expand Down Expand Up @@ -871,7 +926,8 @@ RowVectorPtr MergeJoin::doGetOutput() {
}

if (!input_ || !rightInput_) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_)) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isLeftSemiProjectJoin(joinType_)) {
if (input_ && noMoreRightInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
Expand All @@ -898,7 +954,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
output_->resize(outputSize_);
return std::move(output_);
}
} else if (isRightJoin(joinType_)) {
} else if (isRightJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
if (rightInput_ && noMoreInput_) {
// If output_ is currently wrapping a different buffer, return it
// first.
Expand Down Expand Up @@ -1004,7 +1060,7 @@ RowVectorPtr MergeJoin::doGetOutput() {
// Catch up input_ with rightInput_.
while (compareResult < 0) {
if (isLeftJoin(joinType_) || isAntiJoin(joinType_) ||
isFullJoin(joinType_)) {
isFullJoin(joinType_) || isLeftSemiProjectJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(input_, nullptr)) {
Expand All @@ -1031,7 +1087,8 @@ RowVectorPtr MergeJoin::doGetOutput() {

// Catch up rightInput_ with input_.
while (compareResult > 0) {
if (isRightJoin(joinType_) || isFullJoin(joinType_)) {
if (isRightJoin(joinType_) || isFullJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
// If output_ is currently wrapping a different buffer, return it
// first.
if (prepareOutput(nullptr, rightInput_)) {
Expand Down Expand Up @@ -1139,17 +1196,36 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {

if (!filterRows.hasSelections()) {
// No matches in the output, no need to evaluate the filter.
return output;
if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
return filterOutputForSemiProject(output);
} else {
return output;
}
}

evaluateFilter(filterRows);

FlatVector<bool>* flatMatch{nullptr};
uint64_t* rawValues;

if (isLeftSemiProjectJoin(joinType_) || isRightSemiProjectJoin(joinType_)) {
flatMatch = output->children().back()->as<FlatVector<bool>>();
flatMatch->resize(numRows);
rawValues = flatMatch->mutableRawValues<uint64_t>();
}

// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
if (!isAntiJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, row, false);
}

if (isFullJoin(joinType_)) {
// For filtered rows, it is necessary to insert additional data
// to ensure the result set is complete. Specifically, we
Expand Down Expand Up @@ -1204,7 +1280,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
auto target = fullOuterOutput->childAt(projection.outputChannel);
target->setNull(row + 1, true);
}
} else if (!isRightJoin(joinType_)) {
} else if (
!isRightJoin(joinType_) && !isRightSemiProjectJoin(joinType_)) {
for (auto& projection : rightProjections_) {
auto target = output->childAt(projection.outputChannel);
target->setNull(row, true);
Expand All @@ -1231,13 +1308,22 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
} else {
if (passed) {
if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, i, true);
}
rawIndices[numPassed++] = i;
}
}
} else {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;

if (isLeftSemiProjectJoin(joinType_) ||
isRightSemiProjectJoin(joinType_)) {
bits::setBit(rawValues, i, false);
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ class MergeJoin : public Operator {
/// rows from the left side that have a match on the right.
RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output);

/// Return each row from the left or right side with a boolean flag indicating
/// whether there exists a match on the right or left side.
RowVectorPtr filterOutputForSemiProject(const RowVectorPtr& output);

/// As we populate the results of the join, we track whether a given
/// output row is a result of a match between left and right sides or a miss.
/// We use JoinTracker::addMatch and addMiss methods for that.
Expand Down
54 changes: 54 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,60 @@ TEST_F(MergeJoinTest, semiJoin) {
core::JoinType::kRightSemiFilter);
}

TEST_F(MergeJoinTest, semiJoinProjection) {
auto left = makeRowVector(
{"t0"}, {makeNullableFlatVector<int64_t>({1, 2, 2, 6, std::nullopt})});

auto right = makeRowVector(
{"u0"},
{makeNullableFlatVector<int64_t>(
{1, 2, 2, 7, std::nullopt, std::nullopt})});

createDuckDbTable("t", {left});
createDuckDbTable("u", {right});

auto testSemiJoin = [&](const std::string& filter,
const std::string& sql,
const std::vector<std::string>& outputLayout,
core::JoinType joinType) {
auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
auto plan =
PlanBuilder(planNodeIdGenerator)
.values({left})
.mergeJoin(
{"t0"},
{"u0"},
PlanBuilder(planNodeIdGenerator).values({right}).planNode(),
filter,
outputLayout,
joinType)
.planNode();
AssertQueryBuilder(plan, duckDbQueryRunner_).assertResults(sql);
};

testSemiJoin(
"",
"SELECT t.t0, EXISTS (SELECT * FROM u WHERE t.t0 = u.u0) FROM t",
{"t0", "match"},
core::JoinType::kLeftSemiProject);
testSemiJoin(
"",
"SELECT u0, u0 IN (SELECT * FROM t where t.t0 = u.u0) FROM u",
{"u0", "match"},
core::JoinType::kRightSemiProject);

testSemiJoin(
"t0 > 1",
"SELECT t.t0, EXISTS (SELECT * FROM u WHERE t0 = u0 and t.t0 > 1) FROM t",
{"t0", "match"},
core::JoinType::kLeftSemiProject);
testSemiJoin(
"u0 > 1",
"SELECT u0, u0 IN (SELECT * FROM t where t0 = u0 and u0 > 1) FROM u",
{"u0", "match"},
core::JoinType::kRightSemiProject);
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down
Loading

0 comments on commit b460dc5

Please sign in to comment.