diff --git a/src/Interpreters/FullSortingMergeJoin.h b/src/Interpreters/FullSortingMergeJoin.h index 7688d44f7a96..7e07c2004b63 100644 --- a/src/Interpreters/FullSortingMergeJoin.h +++ b/src/Interpreters/FullSortingMergeJoin.h @@ -21,9 +21,11 @@ namespace ErrorCodes class FullSortingMergeJoin : public IJoin { public: - explicit FullSortingMergeJoin(std::shared_ptr table_join_, const Block & right_sample_block_) + explicit FullSortingMergeJoin(std::shared_ptr table_join_, const Block & right_sample_block_, + int null_direction_ = 1) : table_join(table_join_) , right_sample_block(right_sample_block_) + , null_direction(null_direction_) { LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join"); } @@ -31,6 +33,8 @@ class FullSortingMergeJoin : public IJoin std::string getName() const override { return "FullSortingMergeJoin"; } const TableJoin & getTableJoin() const override { return *table_join; } + int getNullDirection() const { return null_direction; } + bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override { throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called"); @@ -119,6 +123,7 @@ class FullSortingMergeJoin : public IJoin std::shared_ptr table_join; Block right_sample_block; Block totals; + int null_direction; }; } diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index 2d313d4ea5c3..62361bef5e2f 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,7 @@ FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns) } template -int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1) +int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint) { if constexpr (has_left_nulls && has_right_nulls) { @@ -88,35 +89,36 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, } int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos, - const SortCursorImpl & rhs, size_t rpos) + const SortCursorImpl & rhs, size_t rpos, + int null_direction_hint) { for (size_t i = 0; i < lhs.sort_columns_size; ++i) { /// TODO(@vdimir): use nullableCompareAt only if there's nullable columns - int cmp = nullableCompareAt(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos); + int cmp = nullableCompareAt(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint); if (cmp != 0) return cmp; } return 0; } -int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs) +int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint) { - return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow()); + return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), null_direction_hint); } -bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs) +bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint) { /// The last row of left cursor is less than the current row of the right cursor. - int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow()); + int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), null_direction_hint); return cmp < 0; } -int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs) +int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint) { - if (totallyLess(lhs, rhs)) + if (totallyLess(lhs, rhs, null_direction_hint)) return -1; - if (totallyLess(rhs, lhs)) + if (totallyLess(rhs, lhs, null_direction_hint)) return 1; return 0; } @@ -302,6 +304,13 @@ MergeJoinAlgorithm::MergeJoinAlgorithm( size_t right_idx = input_headers[1].getPositionByName(right_key); left_to_right_key_remap[left_idx] = right_idx; } + + const auto *smjPtr = typeid_cast(table_join.get()); + if (smjPtr) + { + null_direction_hint = smjPtr->getNullDirection(); + } + } void MergeJoinAlgorithm::logElapsed(double seconds) @@ -366,7 +375,8 @@ struct AllJoinImpl size_t max_block_size, PaddedPODArray & left_map, PaddedPODArray & right_map, - std::unique_ptr & state) + std::unique_ptr & state, + int null_direction_hint) { right_map.clear(); right_map.reserve(max_block_size); @@ -382,7 +392,7 @@ struct AllJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint); if (cmp == 0) { size_t lnum = nextDistinct(left_cursor.cursor); @@ -517,7 +527,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind) { PaddedPODArray idx_map[2]; - dispatchKind(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state); + dispatchKind(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state, null_direction_hint); assert(idx_map[0].size() == idx_map[1].size()); Chunk result; @@ -576,7 +586,8 @@ struct AnyJoinImpl FullMergeJoinCursor & right_cursor, PaddedPODArray & left_map, PaddedPODArray & right_map, - AnyJoinState & state) + AnyJoinState & state, + int null_direction_hint) { assert(enabled); @@ -599,7 +610,7 @@ struct AnyJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint); if (cmp == 0) { if constexpr (isLeftOrFull(kind)) @@ -723,7 +734,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind) PaddedPODArray idx_map[2]; size_t prev_pos[] = {current_left.getRow(), current_right.getRow()}; - dispatchKind(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state); + dispatchKind(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state, null_direction_hint); assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size()); size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size()); @@ -816,7 +827,7 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge() } /// check if blocks are not intersecting at all - if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor); cmp != 0) + if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor, null_direction_hint); cmp != 0) { if (cmp < 0) { diff --git a/src/Processors/Transforms/MergeJoinTransform.h b/src/Processors/Transforms/MergeJoinTransform.h index 959550067f7a..cf9331abd592 100644 --- a/src/Processors/Transforms/MergeJoinTransform.h +++ b/src/Processors/Transforms/MergeJoinTransform.h @@ -258,6 +258,7 @@ class MergeJoinAlgorithm final : public IMergingAlgorithm JoinPtr table_join; size_t max_block_size; + int null_direction_hint = 1; struct Statistic {