From 6cae676fe65b7ce3b27fec81831f884b3c001e27 Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Sat, 26 Oct 2024 13:54:28 +0800 Subject: [PATCH] left semi bug fix --- velox/exec/MergeJoin.cpp | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 428c9b34caa8..718ebf3311f5 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -619,14 +619,15 @@ bool MergeJoin::addToOutputForLeftJoin() { // right->size()) << "\n"; } - if (isAntiJoin(joinType_) && filter_ && (outputSize_ + (rightEnd - rightStart) > outputBatchSize_)) { + if (isAntiJoin(joinType_) && filter_ && + (outputSize_ + (rightEnd - rightStart) > outputBatchSize_)) { // If we run out of space in the current output_, we will need to // produce a buffer and continue processing left later. In this case, // we cannot leave left as a lazy vector, since we cannot have two // dictionaries wrapping the same lazy vector. output_->resize(outputSize_); loadColumns(currentLeft_, *operatorCtx_->execCtx()); - + leftMatch_->setCursor(l, i); rightMatch_->setCursor(r, rightStart); return true; @@ -637,7 +638,8 @@ bool MergeJoin::addToOutputForLeftJoin() { if (isLeftSemiFilterJoin(joinType_) && filter_) { auto matchedRows = advanceFilterLeftSemiOutput(rightStart, j); if (matchedRows) { - if ((leftEnd - leftStart == 1) || (i + 1 >= leftEnd)) { + if ((numLefts == 1) && + ((leftEnd - leftStart == 1) || (i + 1 >= leftEnd))) { leftMatch_.reset(); rightMatch_.reset(); return true; @@ -663,15 +665,18 @@ bool MergeJoin::addToOutputForLeftJoin() { // input_->childAt(0)->asFlatVector()->valueAt(i) == // 1528291) { // std::cout << "before the outputSize_ is " << outputSize_ << "\n"; - // std::cout << "before the 1496551 output is " << output_->toString(outputSize_ - 1, outputSize_) << "\n"; + // std::cout << "before the 1496551 output is " << + // output_->toString(outputSize_ - 1, outputSize_) << "\n"; // } addOutputRow(left, i, right, j); - // std::cout << "the added output is " << output_->toString(outputSize_ - 1, outputSize_) << "\n"; - // if (input_ && + // std::cout << "the added output is " << + // output_->toString(outputSize_ - 1, outputSize_) << "\n"; if (input_ + // && // input_->childAt(0)->asFlatVector()->valueAt(i) == // 1528291) { // std::cout << "the outputSize_ is " << outputSize_ << "\n"; - // std::cout << "the 1496551 output is " << output_->toString(outputSize_ - 1, outputSize_) << "\n"; + // std::cout << "the 1496551 output is " << + // output_->toString(outputSize_ - 1, outputSize_) << "\n"; // } offsets.emplace_back(matchedNumRows); @@ -714,7 +719,7 @@ bool MergeJoin::addToOutputForLeftJoin() { } } else { // Only remain the one record for multi matched rows in right side. - for (auto i = 0; i < (offsets.size() - 1) ; ++i) { + for (auto i = 0; i < (offsets.size() - 1); ++i) { auto rowIndex = processedRowNums + offsets[i]; joinTracker_->addMiss(rowIndex, true); } @@ -909,7 +914,8 @@ RowVectorPtr MergeJoin::getOutput() { auto output = doGetOutput(); if (output != nullptr && output->size() > 0) { if (filter_) { - // std::cout << "before filter the output is " << output->toString(0, output->size()) << "\n"; + // std::cout << "before filter the output is " << output->toString(0, + // output->size()) << "\n"; output = applyFilter(output); // std::cout << "after filter the output is " // << output->toString(0, output->size()) << "\n"; @@ -919,9 +925,9 @@ RowVectorPtr MergeJoin::getOutput() { filterInput_->childAt(channel).reset(); } - std::cout << "the smj output is " << output->toString(0, output->size()) - << "\n"; - + std::cout << "the smj output is " + << output->toString(0, output->size()) << "\n"; + return output; } @@ -1411,7 +1417,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // if (output_ && // output_->childAt(0)->asFlatVector()->valueAt(i) == // 1543943) { - // std::cout << "the output is " << output_->toString(0, output_->size()) << "\n"; + // std::cout << "the output is " << output_->toString(0, + // output_->size()) << "\n"; // } if (filterRows.isValid(i)) {