Skip to content

Commit

Permalink
Merge pull request ClickHouse#60896 from loudongfeng/master_smj_nullo…
Browse files Browse the repository at this point in the history
…rder

make nulls direction configuable for FullSortingMergeJoin
  • Loading branch information
vdimir authored Mar 21, 2024
2 parents 5b0610b + 96e9043 commit 9b51780
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
7 changes: 6 additions & 1 deletion src/Interpreters/FullSortingMergeJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,20 @@ namespace ErrorCodes
class FullSortingMergeJoin : public IJoin
{
public:
explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_)
explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_,
int null_direction_ = 1)
: table_join(table_join_)
, right_sample_block(right_sample_block_)
, null_direction(null_direction_)
{
LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join");
}

std::string getName() const override { return "FullSortingMergeJoin"; }
const TableJoin & getTableJoin() const override { return *table_join; }

int getNullDirection() const { return null_direction; }

bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called");
Expand Down Expand Up @@ -119,6 +123,7 @@ class FullSortingMergeJoin : public IJoin
std::shared_ptr<TableJoin> table_join;
Block right_sample_block;
Block totals;
int null_direction;
};

}
45 changes: 28 additions & 17 deletions src/Processors/Transforms/MergeJoinTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <Core/SortCursor.h>
#include <Core/SortDescription.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/FullSortingMergeJoin.h>
#include <Interpreters/TableJoin.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Transforms/MergeJoinTransform.h>
Expand Down Expand Up @@ -43,7 +44,7 @@ FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns)
}

template <bool has_left_nulls, bool has_right_nulls>
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1)
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint)
{
if constexpr (has_left_nulls && has_right_nulls)
{
Expand Down Expand Up @@ -88,35 +89,36 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column,
}

int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos,
const SortCursorImpl & rhs, size_t rpos)
const SortCursorImpl & rhs, size_t rpos,
int null_direction_hint)
{
for (size_t i = 0; i < lhs.sort_columns_size; ++i)
{
/// TODO(@vdimir): use nullableCompareAt only if there's nullable columns
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos);
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint);
if (cmp != 0)
return cmp;
}
return 0;
}

int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs)
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint)
{
return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow());
return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), null_direction_hint);
}

bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs)
bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{
/// The last row of left cursor is less than the current row of the right cursor.
int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow());
int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), null_direction_hint);
return cmp < 0;
}

int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs)
int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{
if (totallyLess(lhs, rhs))
if (totallyLess(lhs, rhs, null_direction_hint))
return -1;
if (totallyLess(rhs, lhs))
if (totallyLess(rhs, lhs, null_direction_hint))
return 1;
return 0;
}
Expand Down Expand Up @@ -302,6 +304,13 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
size_t right_idx = input_headers[1].getPositionByName(right_key);
left_to_right_key_remap[left_idx] = right_idx;
}

const auto *smjPtr = typeid_cast<const FullSortingMergeJoin *>(table_join.get());
if (smjPtr)
{
null_direction_hint = smjPtr->getNullDirection();
}

}

void MergeJoinAlgorithm::logElapsed(double seconds)
Expand Down Expand Up @@ -366,7 +375,8 @@ struct AllJoinImpl
size_t max_block_size,
PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map,
std::unique_ptr<AllJoinState> & state)
std::unique_ptr<AllJoinState> & state,
int null_direction_hint)
{
right_map.clear();
right_map.reserve(max_block_size);
Expand All @@ -382,7 +392,7 @@ struct AllJoinImpl
lpos = left_cursor->getRow();
rpos = right_cursor->getRow();

cmp = compareCursors(left_cursor.cursor, right_cursor.cursor);
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0)
{
size_t lnum = nextDistinct(left_cursor.cursor);
Expand Down Expand Up @@ -517,7 +527,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind)
{
PaddedPODArray<UInt64> idx_map[2];

dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state);
dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state, null_direction_hint);
assert(idx_map[0].size() == idx_map[1].size());

Chunk result;
Expand Down Expand Up @@ -576,7 +586,8 @@ struct AnyJoinImpl
FullMergeJoinCursor & right_cursor,
PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map,
AnyJoinState & state)
AnyJoinState & state,
int null_direction_hint)
{
assert(enabled);

Expand All @@ -599,7 +610,7 @@ struct AnyJoinImpl
lpos = left_cursor->getRow();
rpos = right_cursor->getRow();

cmp = compareCursors(left_cursor.cursor, right_cursor.cursor);
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0)
{
if constexpr (isLeftOrFull(kind))
Expand Down Expand Up @@ -723,7 +734,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind)
PaddedPODArray<UInt64> idx_map[2];
size_t prev_pos[] = {current_left.getRow(), current_right.getRow()};

dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state);
dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state, null_direction_hint);

assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size());
size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size());
Expand Down Expand Up @@ -816,7 +827,7 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge()
}

/// check if blocks are not intersecting at all
if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor); cmp != 0)
if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor, null_direction_hint); cmp != 0)
{
if (cmp < 0)
{
Expand Down
1 change: 1 addition & 0 deletions src/Processors/Transforms/MergeJoinTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class MergeJoinAlgorithm final : public IMergingAlgorithm
JoinPtr table_join;

size_t max_block_size;
int null_direction_hint = 1;

struct Statistic
{
Expand Down

0 comments on commit 9b51780

Please sign in to comment.