Skip to content

Commit

Permalink
semi join fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Oct 12, 2024
1 parent 89dcf38 commit e4b18f9
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 7 deletions.
73 changes: 66 additions & 7 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "velox/exec/MergeJoin.h"
#include <iostream>
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/FieldReference.h"
Expand Down Expand Up @@ -89,7 +90,9 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
Expand Down Expand Up @@ -560,7 +563,7 @@ bool MergeJoin::addToOutputForLeftJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_)) {
if (isLeftSemiFilterJoin(joinType_) && !filter_) {
// LeftSemiFilter produce each row from the left at most once.
rightEnd = rightStart + 1;
}
Expand All @@ -578,6 +581,32 @@ bool MergeJoin::addToOutputForLeftJoin() {
}
addOutputRow(left, i, right, j);
}

if (isLeftSemiFilterJoin(joinType_) && filter_) {
auto numRows = (rightEnd - rightStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange(
(outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);

auto flag = true;
for (auto j = rightStart; j < rightEnd; ++j) {
auto rowIndex = processedRowNums + j - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
if (flag) {
flag = false;
} else {
joinTracker_->addMiss(rowIndex);
}
}
}
}
}
}
}
Expand Down Expand Up @@ -638,7 +667,7 @@ bool MergeJoin::addToOutputForRightJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isRightSemiFilterJoin(joinType_)) {
if (isRightSemiFilterJoin(joinType_) && !filter_) {
// RightSemiFilter produce each row from the right at most once.
leftEnd = leftStart + 1;
}
Expand All @@ -656,6 +685,32 @@ bool MergeJoin::addToOutputForRightJoin() {
}
addOutputRow(left, j, right, i);
}

if (isRightSemiFilterJoin(joinType_) && filter_) {
auto numRows = (leftEnd - leftStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange(
(outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);

auto flag = true;
for (auto j = leftStart; j < leftEnd; ++j) {
auto rowIndex = processedRowNums + j - leftStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
if (flag) {
flag = false;
} else {
joinTracker_->addMiss(rowIndex);
}
}
}
}
}
}
}
Expand Down Expand Up @@ -1147,7 +1202,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
if (!isAntiJoin(joinType_)) {
if (!isAntiJoin(joinType_) && !isRightSemiFilterJoin(joinType_) &&
!isLeftSemiFilterJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
Expand Down Expand Up @@ -1235,9 +1291,12 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
}
} else {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;
if (!isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;
}
}
}

Expand Down
45 changes: 45 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,51 @@ TEST_F(MergeJoinTest, semiJoin) {
core::JoinType::kRightSemiFilter);
}

// Exercises left and right semi merge joins that carry an extra filter, using
// inputs where every row on both sides has the same join key.
TEST_F(MergeJoinTest, leftSemiJoin) {
  // Both inputs hold three rows keyed "val1b". With the filter "t1 < u1" only
  // the right rows with u1 == 16 can satisfy the filter (12 < 16); the right
  // row with u1 == 12 cannot.
  auto leftInput = makeRowVector(
      {"t0", "t1"},
      {makeNullableFlatVector<std::string>({"val1b", "val1b", "val1b"}),
       makeNullableFlatVector<int64_t>({12, 12, 12})});

  auto rightInput = makeRowVector(
      {"u0", "u1"},
      {makeNullableFlatVector<std::string>({"val1b", "val1b", "val1b"}),
       makeNullableFlatVector<int64_t>({12, 16, 16})});

  // Register the same data under the table names used by the reference SQL.
  createDuckDbTable("t", {leftInput});
  createDuckDbTable("u", {rightInput});

  // The same join filter is applied to both join flavors below.
  const std::string joinFilter = "t1 < u1";

  // Builds a merge join of `leftInput` with `rightInput` on t0 = u0 for the
  // given join type and output columns, then checks the plan's output against
  // `referenceSql` through AssertQueryBuilder.
  auto verifySemiJoin = [&](core::JoinType joinType,
                            const std::vector<std::string>& outputColumns,
                            const std::string& referenceSql) {
    auto idGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    auto joinPlan =
        PlanBuilder(idGenerator)
            .values({leftInput})
            .mergeJoin(
                {"t0"},
                {"u0"},
                PlanBuilder(idGenerator).values({rightInput}).planNode(),
                joinFilter,
                outputColumns,
                joinType)
            .planNode();
    AssertQueryBuilder(joinPlan, duckDbQueryRunner_)
        .assertResults(referenceSql);
  };

  // Left semi: emits left-side columns.
  verifySemiJoin(
      core::JoinType::kLeftSemiFilter,
      {"t0", "t1"},
      "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 < u1)");

  // Right semi: emits right-side columns.
  verifySemiJoin(
      core::JoinType::kRightSemiFilter,
      {"u0", "u1"},
      "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where t1 < u1)");
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down

0 comments on commit e4b18f9

Please sign in to comment.