Skip to content

Commit

Permalink
Support filter in cross join elimination (apache#13025)
Browse files Browse the repository at this point in the history
* Support filter in cross join elimination

* Support filter in cross join elimination

* Support filter in cross join elimination

* Support filter in cross join elimination
  • Loading branch information
Dandandan authored Oct 22, 2024
1 parent 34fbe8e commit b978cf8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 25 deletions.
61 changes: 37 additions & 24 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule};

use crate::join_key_set::JoinKeySet;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, Result};
use datafusion_common::Result;
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};

#[derive(Default, Debug)]
pub struct EliminateCrossJoin;
Expand Down Expand Up @@ -88,6 +88,7 @@ impl OptimizerRule for EliminateCrossJoin {
let plan_schema = Arc::clone(plan.schema());
let mut possible_join_keys = JoinKeySet::new();
let mut all_inputs: Vec<LogicalPlan> = vec![];
let mut all_filters: Vec<Expr> = vec![];

let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
// if input isn't a join that can potentially be rewritten
Expand Down Expand Up @@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin {
Arc::unwrap_or_clone(input),
&mut possible_join_keys,
&mut all_inputs,
&mut all_filters,
)?;

extract_possible_join_keys(&predicate, &mut possible_join_keys);
Expand All @@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin {
if !can_flatten_join_inputs(&plan) {
return Ok(Transformed::no(plan));
}
flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?;
flatten_join_inputs(
plan,
&mut possible_join_keys,
&mut all_inputs,
&mut all_filters,
)?;
None
} else {
// recursively try to rewrite children
Expand Down Expand Up @@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin {
));
}

if !all_filters.is_empty() {
// Add any filters on top - PushDownFilter can push filters down to applicable join
let first = all_filters.swap_remove(0);
let predicate = all_filters.into_iter().fold(first, and);
left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?);
}

let Some(predicate) = parent_predicate else {
return Ok(Transformed::yes(left));
};
Expand Down Expand Up @@ -206,37 +220,39 @@ fn flatten_join_inputs(
plan: LogicalPlan,
possible_join_keys: &mut JoinKeySet,
all_inputs: &mut Vec<LogicalPlan>,
all_filters: &mut Vec<Expr>,
) -> Result<()> {
match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
// checked in can_flatten_join_inputs
if join.filter.is_some() {
return internal_err!(
"should not have filter in inner join in flatten_join_inputs"
);
if let Some(filter) = join.filter {
all_filters.push(filter);
}
possible_join_keys.insert_all_owned(join.on);
flatten_join_inputs(
Arc::unwrap_or_clone(join.left),
possible_join_keys,
all_inputs,
all_filters,
)?;
flatten_join_inputs(
Arc::unwrap_or_clone(join.right),
possible_join_keys,
all_inputs,
all_filters,
)?;
}
LogicalPlan::CrossJoin(join) => {
flatten_join_inputs(
Arc::unwrap_or_clone(join.left),
possible_join_keys,
all_inputs,
all_filters,
)?;
flatten_join_inputs(
Arc::unwrap_or_clone(join.right),
possible_join_keys,
all_inputs,
all_filters,
)?;
}
_ => {
Expand All @@ -253,13 +269,7 @@ fn flatten_join_inputs(
fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
// can only flatten inner / cross joins
match plan {
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
// The filter of inner join will lost, skip this rule.
// issue: https://github.com/apache/datafusion/issues/4844
if join.filter.is_some() {
return false;
}
}
LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
LogicalPlan::CrossJoin(_) => {}
_ => return false,
};
Expand Down Expand Up @@ -467,12 +477,6 @@ mod tests {
assert_eq!(&starting_schema, optimized_plan.schema())
}

fn assert_optimization_rule_fails(plan: LogicalPlan) {
let rule = EliminateCrossJoin::new();
let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap();
assert!(!transformed_plan.transformed)
}

#[test]
fn eliminate_cross_with_simple_and() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
Expand Down Expand Up @@ -642,8 +646,7 @@ mod tests {
}

#[test]
/// See https://github.com/apache/datafusion/issues/7530
fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> {
fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;
let t3 = test_table_scan_with_name("t3")?;
Expand All @@ -660,7 +663,17 @@ mod tests {
.filter(col("t1.a").gt(lit(15u32)))?
.build()?;

assert_optimization_rule_fails(plan);
let expected = vec![
"Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"
];

assert_optimized_plan_eq(plan, expected);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ logical_plan
01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1
02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0)
03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4
04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1
04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1
05)--------TableScan: t1 projection=[v0, v1]
06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4]
07)----TableScan: t0 projection=[v0, v1]
Expand Down

0 comments on commit b978cf8

Please sign in to comment.