Skip to content

Commit

Permalink
fix semi join result mismatch issue
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Oct 15, 2024
1 parent d141f96 commit f57c265
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 7 deletions.
90 changes: 83 additions & 7 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ void MergeJoin::initialize() {
initializeFilter(joinNode_->filter(), leftType, rightType);

if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() ||
joinNode_->isRightJoin() || joinNode_->isFullJoin()) {
joinNode_->isRightJoin() || joinNode_->isFullJoin() ||
joinNode_->isLeftSemiFilterJoin() ||
joinNode_->isRightSemiFilterJoin()) {
joinTracker_ = JoinTracker(outputBatchSize_, pool());
}
} else if (joinNode_->isAntiJoin()) {
Expand Down Expand Up @@ -543,6 +545,12 @@ bool MergeJoin::addToOutputForLeftJoin() {
: rightMatch_->startIndex;

auto numRights = rightMatch_->inputs.size();

if (isLeftSemiFilterJoin(joinType_) && !filter_) {
// LeftSemiFilter produce each row from the left at most once.
numRights = 1;
}

for (size_t r = firstRightBatch; r < numRights; ++r) {
auto right = rightMatch_->inputs[r];
auto rightStart = r == firstRightBatch ? rightStartIndex : 0;
Expand All @@ -560,7 +568,7 @@ bool MergeJoin::addToOutputForLeftJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isLeftSemiFilterJoin(joinType_)) {
if (isLeftSemiFilterJoin(joinType_) && !filter_) {
// LeftSemiFilter produce each row from the left at most once.
rightEnd = rightStart + 1;
}
Expand All @@ -578,6 +586,35 @@ bool MergeJoin::addToOutputForLeftJoin() {
}
addOutputRow(left, i, right, j);
}

if (isLeftSemiFilterJoin(joinType_) && filter_) {
auto numRows = (rightEnd - rightStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange(
(outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);

auto firstMatchedRow = false;
for (auto j = rightStart; j < rightEnd; ++j) {
auto rowIndex = processedRowNums + j - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
if (!firstMatchedRow) {
firstMatchedRow = true;
} else {
joinTracker_->addMiss(rowIndex);
}
}
}
if (firstMatchedRow) {
break;
}
}
}
}
}
Expand Down Expand Up @@ -622,6 +659,12 @@ bool MergeJoin::addToOutputForRightJoin() {
: leftMatch_->startIndex;

auto numLefts = leftMatch_->inputs.size();

if (isRightSemiFilterJoin(joinType_) && !filter_) {
// Without a filter, a right-semi join emits each right row at most once, so
// a single left batch is enough to prove that a match exists. Cap the bound
// of the left-batch loop below; that bound is 'numLefts' (declared just
// above), not 'numRights', which belongs to the left-join code path.
// NOTE(review): the loop starts at 'firstLeftBatch'; if that can be greater
// than zero here, a bound of 1 would skip the loop entirely - confirm.
numLefts = 1;
}

for (size_t l = firstLeftBatch; l < numLefts; ++l) {
auto left = leftMatch_->inputs[l];
auto leftStart = l == firstLeftBatch ? leftStartIndex : 0;
Expand All @@ -638,7 +681,7 @@ bool MergeJoin::addToOutputForRightJoin() {
// one match on the other side, we could explore specialized algorithms
// or data structures that short-circuit the join process once a match
// is found.
if (isRightSemiFilterJoin(joinType_)) {
if (isRightSemiFilterJoin(joinType_) && !filter_) {
// RightSemiFilter produce each row from the right at most once.
leftEnd = leftStart + 1;
}
Expand All @@ -656,6 +699,35 @@ bool MergeJoin::addToOutputForRightJoin() {
}
addOutputRow(left, j, right, i);
}

if (isRightSemiFilterJoin(joinType_) && filter_) {
auto numRows = (leftEnd - leftStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange(
(outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);

auto firstMatchedRow = false;
for (auto j = leftStart; j < leftEnd; ++j) {
auto rowIndex = processedRowNums + j - leftStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
if (!firstMatchedRow) {
firstMatchedRow = true;
} else {
joinTracker_->addMiss(rowIndex);
}
}
}
if (firstMatchedRow) {
break;
}
}
}
}
}
Expand Down Expand Up @@ -1147,7 +1219,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
// If all matches for a given left-side row fail the filter, add a row to
// the output with nulls for the right-side columns.
auto onMiss = [&](auto row) {
if (!isAntiJoin(joinType_)) {
if (!isAntiJoin(joinType_) && !isRightSemiFilterJoin(joinType_) &&
!isLeftSemiFilterJoin(joinType_)) {
rawIndices[numPassed++] = row;

if (isFullJoin(joinType_)) {
Expand Down Expand Up @@ -1235,9 +1308,12 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
}
} else {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;
if (!isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
// This row doesn't have a match on the right side. Keep it
// unconditionally.
rawIndices[numPassed++] = i;
}
}
}

Expand Down
45 changes: 45 additions & 0 deletions velox/exec/tests/MergeJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,51 @@ TEST_F(MergeJoinTest, semiJoin) {
core::JoinType::kRightSemiFilter);
}

// Checks left-semi and right-semi merge joins that carry an extra filter when
// all rows on both sides share a single join key, so every left row matches
// every right row on the key and only the filter decides the outcome.
TEST_F(MergeJoinTest, leftSemiJoin) {
  // Left input: three rows with the same key and t1 = 12.
  auto left = makeRowVector(
      {"t0", "t1"},
      {makeNullableFlatVector<std::string>({"val1b", "val1b", "val1b"}),
       makeNullableFlatVector<int64_t>({12, 12, 12})});

  // Right input: same key; u1 = 12 fails 't1 < u1' while the two 16s pass it,
  // so the first key match is not the one that satisfies the filter.
  auto right = makeRowVector(
      {"u0", "u1"},
      {makeNullableFlatVector<std::string>({"val1b", "val1b", "val1b"}),
       makeNullableFlatVector<int64_t>({12, 16, 16})});

  createDuckDbTable("t", {left});
  createDuckDbTable("u", {right});

  // Builds a merge join of 'left' and 'right' on t0 = u0 with the given
  // filter, output columns and join type, then compares its result with the
  // result of running 'referenceSql' through DuckDB.
  auto verifySemiJoin = [&](const std::string& filterExpr,
                            const std::string& referenceSql,
                            const std::vector<std::string>& projectedColumns,
                            core::JoinType semiJoinType) {
    auto idGenerator = std::make_shared<core::PlanNodeIdGenerator>();
    auto joinPlan =
        PlanBuilder(idGenerator)
            .values({left})
            .mergeJoin(
                {"t0"},
                {"u0"},
                PlanBuilder(idGenerator).values({right}).planNode(),
                filterExpr,
                projectedColumns,
                semiJoinType)
            .planNode();
    AssertQueryBuilder(joinPlan, duckDbQueryRunner_)
        .assertResults(referenceSql);
  };

  // Both directions use the same filter expression.
  const std::string lessThanFilter = "t1 < u1";

  // Left semi: keep left rows that have a filter-passing match on the right.
  verifySemiJoin(
      lessThanFilter,
      "SELECT t0, t1 FROM t where t0 IN (SELECT u0 from u where t1 < u1)",
      {"t0", "t1"},
      core::JoinType::kLeftSemiFilter);

  // Right semi: keep right rows that have a filter-passing match on the left.
  verifySemiJoin(
      lessThanFilter,
      "SELECT u0, u1 FROM u where u0 IN (SELECT t0 from t where t1 < u1)",
      {"u0", "u1"},
      core::JoinType::kRightSemiFilter);
}

TEST_F(MergeJoinTest, rightJoin) {
auto left = makeRowVector(
{"t0"},
Expand Down

0 comments on commit f57c265

Please sign in to comment.