Skip to content

Commit

Permalink
Fix the anti join
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Oct 22, 2024
1 parent 9834e4b commit b27b04f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
43 changes: 36 additions & 7 deletions velox/exec/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/
#include "velox/exec/MergeJoin.h"
#include <iostream>
#include "velox/exec/OperatorUtils.h"
#include "velox/exec/Task.h"
#include "velox/expression/FieldReference.h"
Expand Down Expand Up @@ -643,6 +642,37 @@ bool MergeJoin::addToOutputForLeftJoin() {
break;
}
}

if (isAntiJoin(joinType_) && filter_) {
auto numRows = (rightEnd - rightStart);
SelectivityVector matchingRows{outputSize_, false};
matchingRows.setValidRange(
(outputSize_ - numRows), outputSize_, true);
matchingRows.updateBounds();

evaluateFilter(matchingRows);

auto processedRowNums = (outputSize_ - numRows);

auto matchedRow = false;
for (auto j = rightStart; j < rightEnd; ++j) {
auto rowIndex = processedRowNums + j - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
if (passed) {
matchedRow = true;
}
}

if (matchedRow) {
for (auto j = rightStart; j < rightEnd; ++j) {
auto rowIndex = processedRowNums + j - rightStart;
const bool passed = !decodedFilterResult_.isNullAt(rowIndex) &&
decodedFilterResult_.valueAt<bool>(rowIndex);
joinTracker_->addMiss(rowIndex, true);
}
}
}
}
}
}
Expand Down Expand Up @@ -838,8 +868,6 @@ RowVectorPtr MergeJoin::getOutput() {
for (const auto [channel, _] : filterInputToOutputChannel_) {
filterInput_->childAt(channel).reset();
}
std::cout << "the output is " << output->toString(0, output->size())
<< "\n";
return output;
}

Expand All @@ -854,8 +882,8 @@ RowVectorPtr MergeJoin::getOutput() {
// No rows survived the filter for anti join. Get more rows.
continue;
} else {
std::cout << "the output is " << output->toString(0, output->size())
<< "\n";
// std::cout << "the output is " << output->toString(0, output->size())
// << "\n";
return output;
}
}
Expand Down Expand Up @@ -900,8 +928,6 @@ RowVectorPtr MergeJoin::getOutput() {
continue;
}

std::cout << "begin return nullptr for smj output"
<< "\n";
return nullptr;
}
}
Expand Down Expand Up @@ -1344,6 +1370,9 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) {
}
}
} else {
if (isAntiJoin(joinType_) && joinTracker_->multiMatchedRows(i)) {
continue;
}
if (!isLeftSemiFilterJoin(joinType_) &&
!isRightSemiFilterJoin(joinType_)) {
// This row doesn't have a match on the right side. Keep it
Expand Down
14 changes: 13 additions & 1 deletion velox/exec/MergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ class MergeJoin : public Operator {
: matchingRows_{numRows, false} {
leftRowNumbers_ = AlignedBuffer::allocate<vector_size_t>(numRows, pool);
rawLeftRowNumbers_ = leftRowNumbers_->asMutable<vector_size_t>();

leftMultiMatchedRows_ =
AlignedBuffer::allocate<vector_size_t>(numRows, pool);
rawleftMultiMatchedRows_ = leftMultiMatchedRows_->asMutable<bool>();
}

/// Records a row of output that corresponds to a match between a left-side
Expand Down Expand Up @@ -316,9 +320,10 @@ class MergeJoin : public Operator {
/// row that has no match on the right-side. The caller must call addMatch
/// or addMiss method for each row of output in order, starting with the
/// first row.
void addMiss(vector_size_t outputIndex) {
void addMiss(vector_size_t outputIndex, bool multiMatched = false) {
matchingRows_.setValid(outputIndex, false);
resetLastVector();
rawleftMultiMatchedRows_[outputIndex] = multiMatched;
}

/// Clear the left-side vector and index of the last added output row. The
Expand Down Expand Up @@ -363,6 +368,10 @@ class MergeJoin : public Operator {
return currentLeftRowNumber_ == rawLeftRowNumbers_[row];
}

bool multiMatchedRows(vector_size_t rowIndex) {
return rawleftMultiMatchedRows_[rowIndex];
}

/// Called when all rows from the current output batch are processed and the
/// next batch of output will start with a new left-side row or there will
/// be no more batches. Calls 'onMiss' for the last left-side row if the
Expand Down Expand Up @@ -396,6 +405,9 @@ class MergeJoin : public Operator {
BufferPtr leftRowNumbers_;
vector_size_t* rawLeftRowNumbers_;

BufferPtr leftMultiMatchedRows_;
bool* rawleftMultiMatchedRows_;

// Synthetic number assigned to the last added "match" row or zero if no row
// has been added yet.
vector_size_t lastLeftRowNumber_{0};
Expand Down

0 comments on commit b27b04f

Please sign in to comment.