Skip to content

Commit

Permalink
[BugFix] Broadcast Join should not generate nondetermistic GRF (#44111)
Browse files Browse the repository at this point in the history
Signed-off-by: satanson <[email protected]>
(cherry picked from commit ecbc790)

# Conflicts:
#	be/src/exec/exec_node.h
#	be/src/exec/vectorized/hash_join_node.cpp
#	fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java
#	fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java
#	fe/fe-core/src/main/java/com/starrocks/planner/RuntimeFilterDescription.java
#	fe/fe-core/src/main/java/com/starrocks/planner/RuntimeFilterPushDownContext.java
  • Loading branch information
satanson authored and mergify[bot] committed May 7, 2024
1 parent 795be77 commit c7f3558
Show file tree
Hide file tree
Showing 22 changed files with 1,743 additions and 8 deletions.
2 changes: 1 addition & 1 deletion be/src/exec/exec_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void ExecNode::push_down_join_runtime_filter(RuntimeState* state, vectorized::Ru
if (_type != TPlanNodeType::AGGREGATION_NODE && _type != TPlanNodeType::ANALYTIC_EVAL_NODE) {
push_down_join_runtime_filter_to_children(state, collector);
}
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);
}

void ExecNode::push_down_join_runtime_filter_to_children(RuntimeState* state,
Expand Down
12 changes: 12 additions & 0 deletions be/src/exec/exec_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ class ExecNode {
// Names of counters shared by all exec nodes
static const std::string ROW_THROUGHPUT_COUNTER;

<<<<<<< HEAD
=======
static void may_add_chunk_accumulate_operator(OpFactories& ops, pipeline::PipelineBuilderContext* context, int id);

void set_children(std::vector<ExecNode*>&& children) { _children = std::move(children); }

const std::vector<ExecNode*>& children() const { return _children; }

[[nodiscard]] static Status create_vectorized_node(RuntimeState* state, ObjectPool* pool, const TPlanNode& tnode,
const DescriptorTbl& descs, ExecNode** node);

>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
protected:
friend class DataSink;

Expand Down
41 changes: 41 additions & 0 deletions be/src/exec/pipeline/fragment_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,41 @@ int FragmentExecutor::_calc_query_expired_seconds(const UnifiedExecPlanFragmentP
return QueryContext::DEFAULT_EXPIRE_SECONDS;
}

// Walks the plan tree rooted at `node`, visiting children before the node itself, and adds to `filter_ids`
// the filter id of every build runtime filter owned by a HASH_JOIN_NODE whose distribution mode is
// SHUFFLE_HASH_BUCKET. Nodes of any other type or distribution mode contribute nothing.
static void collect_shuffle_hash_bucket_rf_ids(const ExecNode* node, std::unordered_set<int32_t>& filter_ids) {
    // Post-order traversal: descend into every subtree first.
    for (const auto* kid : node->children()) {
        collect_shuffle_hash_bucket_rf_ids(kid, filter_ids);
    }

    if (node->type() != TPlanNodeType::HASH_JOIN_NODE) {
        return;
    }

    const auto* hash_join = down_cast<const HashJoinNode*>(node);
    if (hash_join->distribution_mode() != TJoinDistributionMode::SHUFFLE_HASH_BUCKET) {
        return;
    }

    for (const auto* build_rf : hash_join->build_runtime_filters()) {
        filter_ids.insert(build_rf->filter_id());
    }
}

// Recursively gathers the plan node ids of the subtree rooted at `node`.
//
// Returns the set containing `node->id()` plus the ids of all of its descendants.
//
// Side effect: whenever `node` is a HASH_JOIN_NODE with BROADCAST distribution that can generate a global
// runtime filter, every node id found under its second child (index 1, the "right" side referred to by the
// function name) is added to `broadcast_join_right_offsprings`.
static std::unordered_set<int32_t> collect_broadcast_join_right_offsprings(
        const ExecNode* node, BroadcastJoinRightOffsprings& broadcast_join_right_offsprings) {
    std::vector<std::unordered_set<int32_t>> offsprings_per_child;
    std::unordered_set<int32_t> offsprings;
    offsprings_per_child.reserve(node->children().size());
    for (const auto* child : node->children()) {
        auto child_offspring = collect_broadcast_join_right_offsprings(child, broadcast_join_right_offsprings);
        offsprings.insert(child_offspring.begin(), child_offspring.end());
        // Keep the per-child set so the right subtree can be singled out below.
        offsprings_per_child.push_back(std::move(child_offspring));
    }
    offsprings.insert(node->id());

    if (node->type() == TPlanNodeType::HASH_JOIN_NODE) {
        const auto* join_node = down_cast<const HashJoinNode*>(node);
        // Guard the index: operator[] performs no bounds check, so a join node that unexpectedly has fewer
        // than two children must not reach offsprings_per_child[1].
        if (offsprings_per_child.size() > 1 &&
            join_node->distribution_mode() == TJoinDistributionMode::BROADCAST &&
            join_node->can_generate_global_runtime_filter()) {
            const auto& right_offsprings = offsprings_per_child[1];
            broadcast_join_right_offsprings.insert(right_offsprings.begin(), right_offsprings.end());
        }
    }
    return offsprings;
}

Status FragmentExecutor::_prepare_exec_plan(ExecEnv* exec_env, const UnifiedExecPlanFragmentParams& request) {
auto* runtime_state = _fragment_ctx->runtime_state();
auto* obj_pool = runtime_state->obj_pool();
Expand All @@ -307,6 +342,12 @@ Status FragmentExecutor::_prepare_exec_plan(ExecEnv* exec_env, const UnifiedExec
// Set up plan
RETURN_IF_ERROR(ExecNode::create_tree(runtime_state, obj_pool, fragment.plan, desc_tbl, &_fragment_ctx->plan()));
ExecNode* plan = _fragment_ctx->plan();
std::unordered_set<int32_t> filter_ids;
collect_shuffle_hash_bucket_rf_ids(plan, filter_ids);
runtime_state->set_shuffle_hash_bucket_rf_ids(std::move(filter_ids));
BroadcastJoinRightOffsprings broadcast_join_right_offsprings_map;
collect_broadcast_join_right_offsprings(plan, broadcast_join_right_offsprings_map);
runtime_state->set_broadcast_join_right_offsprings(std::move(broadcast_join_right_offsprings_map));
plan->push_down_join_runtime_filter_recursively(runtime_state);
std::vector<TupleSlotMapping> empty_mappings;
plan->push_down_tuple_slot_mappings(runtime_state, empty_mappings);
Expand Down
2 changes: 1 addition & 1 deletion be/src/exec/vectorized/aggregate/aggregate_base_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Status AggregateBaseNode::close(RuntimeState* state) {
void AggregateBaseNode::push_down_join_runtime_filter(RuntimeState* state,
vectorized::RuntimeFilterProbeCollector* collector) {
// accept runtime filters from parent if possible.
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);

// check to see if runtime filters can be rewritten
auto& descriptors = _runtime_filter_collector.descriptors();
Expand Down
27 changes: 27 additions & 0 deletions be/src/exec/vectorized/hash_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,4 +942,31 @@ Status HashJoinNode::_create_implicit_local_join_runtime_filters(RuntimeState* s
return Status::OK();
}

<<<<<<< HEAD:be/src/exec/vectorized/hash_join_node.cpp
} // namespace starrocks::vectorized
=======
bool HashJoinNode::can_generate_global_runtime_filter() const {
return std::any_of(_build_runtime_filters.begin(), _build_runtime_filters.end(),
[](const RuntimeFilterBuildDescriptor* rf) { return rf->has_remote_targets(); });
}

// Accepts runtime filters offered by the parent through `collector`.
//
// Nothing happens when `collector` is empty. For INNER, LEFT SEMI and RIGHT SEMI joins the work is delegated
// to ExecNode::push_down_join_runtime_filter. For every other join type only this node's own collector takes
// the filters via push_down(); the base-class implementation is not invoked.
void HashJoinNode::push_down_join_runtime_filter(RuntimeState* state, RuntimeFilterProbeCollector* collector) {
    if (collector->empty()) {
        return;
    }

    const bool use_base_push_down = _join_type == TJoinOp::INNER_JOIN || _join_type == TJoinOp::LEFT_SEMI_JOIN ||
                                    _join_type == TJoinOp::RIGHT_SEMI_JOIN;
    if (use_base_push_down) {
        ExecNode::push_down_join_runtime_filter(state, collector);
    } else {
        _runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);
    }
}

// Returns the join distribution mode stored in _distribution_mode.
TJoinDistributionMode::type HashJoinNode::distribution_mode() const {
return _distribution_mode;
}

// Returns a const reference to this node's list of build runtime filter descriptors (_build_runtime_filters).
const std::list<RuntimeFilterBuildDescriptor*>& HashJoinNode::build_runtime_filters() const {
return _build_runtime_filters;
}

} // namespace starrocks
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111)):be/src/exec/hash_join_node.cpp
4 changes: 4 additions & 0 deletions be/src/exec/vectorized/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class HashJoinNode final : public ExecNode {
Status get_next(RuntimeState* state, ChunkPtr* chunk, bool* eos) override;
Status close(RuntimeState* state) override;
pipeline::OpFactories decompose_to_pipeline(pipeline::PipelineBuilderContext* context) override;
bool can_generate_global_runtime_filter() const;
TJoinDistributionMode::type distribution_mode() const;
const std::list<RuntimeFilterBuildDescriptor*>& build_runtime_filters() const;
void push_down_join_runtime_filter(RuntimeState* state, RuntimeFilterProbeCollector* collector) override;

private:
static bool _has_null(const ColumnPtr& column);
Expand Down
2 changes: 1 addition & 1 deletion be/src/exec/vectorized/project_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void ProjectNode::push_down_tuple_slot_mappings(RuntimeState* state,
void ProjectNode::push_down_join_runtime_filter(RuntimeState* state,
vectorized::RuntimeFilterProbeCollector* collector) {
// accept runtime filters from parent if possible.
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);

// check to see if runtime filters can be rewritten
auto& descriptors = _runtime_filter_collector.descriptors();
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/vectorized/runtime_filter_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ void RuntimeFilterProbeCollector::update_selectivity(vectorized::Chunk* chunk,
}
}

void RuntimeFilterProbeCollector::push_down(RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id,
RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
std::set<TPlanNodeId>& local_rf_waiting_set) {
if (this == parent) return;
auto iter = parent->_descriptors.begin();
Expand All @@ -524,7 +525,8 @@ void RuntimeFilterProbeCollector::push_down(RuntimeFilterProbeCollector* parent,
++iter;
continue;
}
if (desc->is_bound(tuple_ids)) {
if (desc->is_bound(tuple_ids) && !(state->broadcast_join_right_offsprings().contains(target_plan_node_id) &&
state->shuffle_hash_bucket_rf_ids().contains(desc->filter_id()))) {
add_descriptor(desc);
if (desc->is_local()) {
local_rf_waiting_set.insert(desc->build_plan_node_id());
Expand Down
4 changes: 2 additions & 2 deletions be/src/exprs/vectorized/runtime_filter_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ class RuntimeFilterProbeCollector {
void add_descriptor(RuntimeFilterProbeDescriptor* desc);
// accept RuntimeFilterCollector from parent node
// which means parent node to push down runtime filter.
void push_down(RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
std::set<TPlanNodeId>& rf_waiting_set);
void push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id, RuntimeFilterProbeCollector* parent,
const std::vector<TupleId>& tuple_ids, std::set<TPlanNodeId>& rf_waiting_set);
std::map<int32_t, RuntimeFilterProbeDescriptor*>& descriptors() { return _descriptors; }
const std::map<int32_t, RuntimeFilterProbeDescriptor*>& descriptors() const { return _descriptors; }

Expand Down
3 changes: 3 additions & 0 deletions be/src/runtime/runtime_filter_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ void RuntimeFilterPort::publish_runtime_filters(std::list<vectorized::RuntimeFil
auto* filter = rf_desc->runtime_filter();

if (filter == nullptr || !rf_desc->has_remote_targets()) continue;
// An empty runtime filter generated by a broadcast join cannot be used as a global runtime filter, because it
// may be short-circuited by an empty probe side.
if (rf_desc->join_mode() == TRuntimeFilterBuildJoinMode::BORADCAST && filter->size() == 0) continue;

auto directly_send_broadcast_grf = rf_desc->join_mode() == TRuntimeFilterBuildJoinMode::BORADCAST &&
!rf_desc->broadcast_grf_senders().empty();
Expand Down
19 changes: 18 additions & 1 deletion be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RowDescriptor;
class RuntimeFilterPort;
class QueryStatistics;
class QueryStatisticsRecvr;

using BroadcastJoinRightOffsprings = std::unordered_set<int32_t>;
namespace pipeline {
class QueryContext;
}
Expand Down Expand Up @@ -345,6 +345,20 @@ class RuntimeState {

bool use_page_cache();

void set_shuffle_hash_bucket_rf_ids(std::unordered_set<int32_t>&& filter_ids) {
this->_shuffle_hash_bucket_rf_ids = std::move(filter_ids);
}

const std::unordered_set<int32_t>& shuffle_hash_bucket_rf_ids() const { return this->_shuffle_hash_bucket_rf_ids; }

// Takes ownership of `broadcast_join_right_offsprings` (moved in) and stores it on this RuntimeState.
void set_broadcast_join_right_offsprings(BroadcastJoinRightOffsprings&& broadcast_join_right_offsprings) {
    _broadcast_join_right_offsprings = std::move(broadcast_join_right_offsprings);
}

// Read-only view of the stored set of plan node ids located under the right side of broadcast joins.
const BroadcastJoinRightOffsprings& broadcast_join_right_offsprings() const {
    return _broadcast_join_right_offsprings;
}

private:
// Set per-query state.
void _init(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
Expand Down Expand Up @@ -464,6 +478,9 @@ class RuntimeState {
pipeline::FragmentContext* _fragment_ctx = nullptr;

bool _enable_pipeline_engine = false;

std::unordered_set<int32_t> _shuffle_hash_bucket_rf_ids;
BroadcastJoinRightOffsprings _broadcast_join_right_offsprings;
};

#define LIMIT_EXCEEDED(tracker, state, msg) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ public boolean isLeftAntiJoin() {
return this == LEFT_ANTI_JOIN || this == NULL_AWARE_LEFT_ANTI_JOIN;
}

/**
 * @return true only when this operator is {@code NULL_AWARE_LEFT_ANTI_JOIN}.
 */
public boolean isNullAwareLeftAntiJoin() {
return this == NULL_AWARE_LEFT_ANTI_JOIN;
}

public boolean isRightSemiJoin() {
return this == RIGHT_SEMI_JOIN;
}
Expand Down Expand Up @@ -140,6 +144,10 @@ public static Set<JoinOperator> semiAntiJoinSet() {
public static Set<JoinOperator> innerCrossJoinSet() {
return Sets.newHashSet(INNER_JOIN, CROSS_JOIN);
}

/**
 * @return false for left outer, full outer and left anti joins (including the null-aware variant,
 *         which {@link #isLeftAntiJoin()} also covers); true for every other join operator.
 */
public boolean canGenerateRuntimeFilter() {
    return !isLeftOuterJoin() && !isFullOuterJoin() && !isLeftAntiJoin();
}
}


86 changes: 86 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.starrocks.analysis.Analyzer;
import com.starrocks.analysis.BinaryPredicate;
<<<<<<< HEAD
=======
import com.starrocks.analysis.BinaryType;
import com.starrocks.analysis.DescriptorTable;
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.JoinOperator;
import com.starrocks.analysis.SlotId;
Expand Down Expand Up @@ -158,7 +164,12 @@ public List<Expr> getProbePartitionByExprs() {
}

@Override
<<<<<<< HEAD
public void buildRuntimeFilters(IdGenerator<RuntimeFilterId> runtimeFilterIdIdGenerator) {
=======
public void buildRuntimeFilters(IdGenerator<RuntimeFilterId> runtimeFilterIdIdGenerator, DescriptorTable descTbl,
ExecGroupSets execGroupSets) {
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
JoinOperator joinOp = getJoinOp();
PlanNode inner = getChild(1);
Expand Down Expand Up @@ -271,13 +282,79 @@ public boolean pushDownRuntimeFiltersForChild(RuntimeFilterDescription descripti
partitionByExprs, candidatesOfSlotExprsForChild(partitionByExprs, childIdx), childIdx, false);
}

/**
 * Tries to push a runtime filter down to both children of this join.
 *
 * <p>This only applies when the join is neither a cross join nor a null-aware left anti join, has at
 * least one equi-join conjunct, and {@code probeExpr} is a {@link SlotRef} whose slot id is used by an
 * {@code EQ} equi-join conjunct.
 *
 * @param context          push-down context forwarded to the children
 * @param probeExpr        probe expression of the runtime filter
 * @param partitionByExprs partition-by expressions forwarded to the children
 * @return {@code Optional.empty()} when the bilateral push-down does not apply; otherwise an Optional
 *         holding whether the filter was pushed down to at least one child
 */
private Optional<Boolean> pushDownRuntimeFilterBilaterally(RuntimeFilterPushDownContext context,
                                                           Expr probeExpr,
                                                           List<Expr> partitionByExprs) {
    boolean joinAllowsBilateralPush =
            !joinOp.isCrossJoin() && !joinOp.isNullAwareLeftAntiJoin() && !eqJoinConjuncts.isEmpty();
    if (!joinAllowsBilateralPush || !(probeExpr instanceof SlotRef)) {
        return Optional.empty();
    }

    SlotRef probeSlot = probeExpr.cast();
    int probeSlotId = probeSlot.getSlotId().asInt();
    boolean probeExprIsJoinColumn = eqJoinConjuncts.stream()
            .filter(conj -> conj.getOp().equals(BinaryType.EQ))
            .anyMatch(conj -> conj.getUsedSlotIds().contains(probeSlotId));
    if (!probeExprIsJoinColumn) {
        return Optional.empty();
    }

    // For join types other than null-aware left anti join and cross join, a runtime filter whose probe
    // expr is a join column can always be pushed down to both sides of the join. Both children are
    // always attempted; neither call is skipped when the other one succeeds.
    boolean pushedToLeft = pushDownRuntimeFiltersForChild(context, probeExpr, partitionByExprs, 0);
    boolean pushedToRight = pushDownRuntimeFiltersForChild(context, probeExpr, partitionByExprs, 1);
    return Optional.of(pushedToLeft || pushedToRight);
}


/**
 * Tries to push a runtime filter down to a single child of this join, stopping at the first success.
 *
 * <p>Candidate children depend on the join operator: child 0 for left anti / left outer joins, child 1
 * for right anti / right outer joins, children 0 then 1 for inner, semi and cross joins, and no child
 * for any other operator.
 *
 * @param context          push-down context forwarded to the children
 * @param probeExpr        probe expression of the runtime filter
 * @param partitionByExprs partition-by expressions handed to
 *                         {@code canPushDownRuntimeFilterCrossExchange} to obtain candidate lists
 * @return always a present Optional: {@code true} if some child accepted the filter, {@code false}
 *         otherwise (including when {@code canPushDownRuntimeFilterCrossExchange} returns empty)
 */
private Optional<Boolean> pushDownRuntimeFilterUnilaterally(RuntimeFilterPushDownContext context,
                                                            Expr probeExpr,
                                                            List<Expr> partitionByExprs) {
    final List<Integer> childIndexes;
    if (joinOp.isLeftAntiJoin() || joinOp.isLeftOuterJoin()) {
        childIndexes = ImmutableList.of(0);
    } else if (joinOp.isRightAntiJoin() || joinOp.isRightOuterJoin()) {
        childIndexes = ImmutableList.of(1);
    } else if (joinOp.isInnerJoin() || joinOp.isSemiJoin() || joinOp.isCrossJoin()) {
        childIndexes = ImmutableList.of(0, 1);
    } else {
        childIndexes = ImmutableList.of();
    }

    Optional<List<List<Expr>>> crossExchangeCandidates = canPushDownRuntimeFilterCrossExchange(partitionByExprs);
    if (crossExchangeCandidates.isEmpty()) {
        return Optional.of(false);
    }
    List<List<Expr>> candidates = crossExchangeCandidates.get();

    for (int childIdx : childIndexes) {
        PlanNode child = getChild(childIdx);
        if (candidates.isEmpty()) {
            // No candidate partition-by lists: try once with an empty list.
            if (child.pushDownRuntimeFilters(context, probeExpr, Lists.newArrayList())) {
                return Optional.of(true);
            }
            continue;
        }
        // Otherwise try each candidate list in order until one is accepted.
        for (List<Expr> candidate : candidates) {
            if (child.pushDownRuntimeFilters(context, probeExpr, candidate)) {
                return Optional.of(true);
            }
        }
    }
    return Optional.of(false);
}

@Override
public boolean pushDownRuntimeFilters(RuntimeFilterDescription description, Expr probeExpr, List<Expr> partitionByExprs) {
if (!canPushDownRuntimeFilter()) {
return false;
}

if (probeExpr.isBoundByTupleIds(getTupleIds())) {
<<<<<<< HEAD
boolean hasPushedDown = false;
// If probeExpr is SlotRef(a) and an equalJoinConjunct SlotRef(a)=SlotRef(b) exists in SemiJoin
// or InnerJoin, then the rf also can be pushed down to both sides of HashJoin because SlotRef(a) and
Expand All @@ -290,6 +367,15 @@ public boolean pushDownRuntimeFilters(RuntimeFilterDescription description, Expr
// fall back to PlanNode.pushDownRuntimeFilters for HJ if rf cannot be pushed down via equivalent
// equalJoinConjuncts
if (hasPushedDown || super.pushDownRuntimeFilters(description, probeExpr, partitionByExprs)) {
=======

Optional<Boolean> pushDownResult = pushDownRuntimeFilterBilaterally(context, probeExpr, partitionByExprs);
if (pushDownResult.isEmpty()) {
pushDownResult = pushDownRuntimeFilterUnilaterally(context, probeExpr, partitionByExprs);
}

if (pushDownResult.isPresent() && pushDownResult.get()) {
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
return true;
}

Expand Down
Loading

0 comments on commit c7f3558

Please sign in to comment.