diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 6639d1ca39e5..249910ecafe9 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ #include "velox/exec/MergeJoin.h" -#include #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" @@ -643,6 +642,37 @@ bool MergeJoin::addToOutputForLeftJoin() { break; } } + + if (isAntiJoin(joinType_) && filter_) { + auto numRows = (rightEnd - rightStart); + SelectivityVector matchingRows{outputSize_, false}; + matchingRows.setValidRange( + (outputSize_ - numRows), outputSize_, true); + matchingRows.updateBounds(); + + evaluateFilter(matchingRows); + + auto processedRowNums = (outputSize_ - numRows); + + auto matchedRow = false; + for (auto j = rightStart; j < rightEnd; ++j) { + auto rowIndex = processedRowNums + j - rightStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + if (passed) { + matchedRow = true; + } + } + + if (matchedRow) { + for (auto j = rightStart; j < rightEnd; ++j) { + auto rowIndex = processedRowNums + j - rightStart; + const bool passed = !decodedFilterResult_.isNullAt(rowIndex) && + decodedFilterResult_.valueAt(rowIndex); + joinTracker_->addMiss(rowIndex, true); + } + } + } } } } @@ -838,8 +868,6 @@ RowVectorPtr MergeJoin::getOutput() { for (const auto [channel, _] : filterInputToOutputChannel_) { filterInput_->childAt(channel).reset(); } - std::cout << "the output is " << output->toString(0, output->size()) - << "\n"; return output; } @@ -854,8 +882,8 @@ RowVectorPtr MergeJoin::getOutput() { // No rows survived the filter for anti join. Get more rows. continue; } else { - std::cout << "the output is " << output->toString(0, output->size()) - << "\n"; + // std::cout << "the output is " << output->toString(0, output->size()) + // << "\n"; return output; } } @@ -900,8 +928,6 @@ RowVectorPtr MergeJoin::getOutput() { continue; } - std::cout << "begin return nullptr for smj output" - << "\n"; return nullptr; } } @@ -1344,6 +1370,9 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } } } else { + if (isAntiJoin(joinType_) && joinTracker_->multiMatchedRows(i)) { + continue; + } if (!isLeftSemiFilterJoin(joinType_) && !isRightSemiFilterJoin(joinType_)) { // This row doesn't have a match on the right side. Keep it diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index aaf8633a9ab5..506b1f2a0145 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -281,6 +281,10 @@ class MergeJoin : public Operator { : matchingRows_{numRows, false} { leftRowNumbers_ = AlignedBuffer::allocate(numRows, pool); rawLeftRowNumbers_ = leftRowNumbers_->asMutable(); + + leftMultiMatchedRows_ = + AlignedBuffer::allocate(numRows, pool); + rawleftMultiMatchedRows_ = leftMultiMatchedRows_->asMutable(); } /// Records a row of output that corresponds to a match between a left-side @@ -316,9 +320,10 @@ class MergeJoin : public Operator { /// row that has no match on the right-side. The caller must call addMatch /// or addMiss method for each row of output in order, starting with the /// first row. - void addMiss(vector_size_t outputIndex) { + void addMiss(vector_size_t outputIndex, bool multiMatched = false) { matchingRows_.setValid(outputIndex, false); resetLastVector(); + rawleftMultiMatchedRows_[outputIndex] = multiMatched; } /// Clear the left-side vector and index of the last added output row. The @@ -363,6 +368,10 @@ class MergeJoin : public Operator { return currentLeftRowNumber_ == rawLeftRowNumbers_[row]; } + bool multiMatchedRows(vector_size_t rowIndex) { + return rawleftMultiMatchedRows_[rowIndex]; + } + /// Called when all rows from the current output batch are processed and the /// next batch of output will start with a new left-side row or there will /// be no more batches. Calls 'onMiss' for the last left-side row if the @@ -396,6 +405,9 @@ class MergeJoin : public Operator { BufferPtr leftRowNumbers_; vector_size_t* rawLeftRowNumbers_; + BufferPtr leftMultiMatchedRows_; + bool* rawleftMultiMatchedRows_; + // Synthetic number assigned to the last added "match" row or zero if no row // has been added yet. vector_size_t lastLeftRowNumber_{0};