chore(planner): fix merge predicates with different data types (databendlabs#15725)

* chore(planner): fix merge predicates with different data types.

* chore(planner): add sqllogictest

* chore: remove useless code
Dousir9 authored Jun 4, 2024
1 parent 7cf7599 commit 3dde44c
Showing 3 changed files with 159 additions and 14 deletions.
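
The core of the fix is in `merge_predicate` (first file below): when the two predicates being merged carry constants of different data types (here BIGINT vs BIGINT UNSIGNED), the constants are first brought to a common super type via `common_super_type` and `adjust_scalar` before they are compared, and if no common type exists the optimizer conservatively returns `MergeResult::All` and keeps both predicates. `add_expr_predicate`, `merge_predicate`, and `derive_predicates` now return `Result` so type-check errors can propagate. What follows is a minimal, self-contained sketch of that idea, not Databend's actual code — `Constant`, `widen`, and `merge_eq` are invented stand-ins for `ConstantExpr`, the super-type cast, and `merge_predicate`.

// Minimal sketch of merging equality predicates whose constants have
// different data types. Everything here is illustrative, not Databend's API.

#[derive(Clone, Debug, PartialEq)]
enum Constant {
    Int64(i64),  // e.g. a literal compared against a BIGINT column
    UInt64(u64), // e.g. a literal compared against a BIGINT UNSIGNED column
}

#[derive(Debug, PartialEq)]
enum MergeResult {
    None, // contradictory: the conjunction is always false
    Left, // the left predicate subsumes the right one
}

// Bring both constants to a common comparable representation (i128 here),
// standing in for the cast to a common super type before merging.
fn widen(c: &Constant) -> i128 {
    match c {
        Constant::Int64(v) => i128::from(*v),
        Constant::UInt64(v) => i128::from(*v),
    }
}

// Merge two equality predicates on the same expression. Comparing the raw
// constants would treat Int64(869550529) and UInt64(869550529) as different
// values; widening first lets them merge into a single predicate.
fn merge_eq(left: &Constant, right: &Constant) -> MergeResult {
    if widen(left) == widen(right) {
        MergeResult::Left
    } else {
        MergeResult::None
    }
}

fn main() {
    // t1.id = 869550529 (BIGINT) vs t2.id = 869550529 (BIGINT UNSIGNED)
    let a = Constant::Int64(869_550_529);
    let b = Constant::UInt64(869_550_529);
    assert_eq!(merge_eq(&a, &b), MergeResult::Left);

    // Different values remain contradictory once widened.
    assert_eq!(merge_eq(&a, &Constant::UInt64(1)), MergeResult::None);
}
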
53 changes: 39 additions & 14 deletions src/query/sql/src/planner/optimizer/filter/infer_filter.rs
@@ -16,10 +16,12 @@ use std::collections::HashMap;
use std::collections::HashSet;

use databend_common_exception::Result;
use databend_common_expression::type_check::common_super_type;
use databend_common_expression::types::DataType;
use databend_common_expression::types::NumberDataType;
use databend_common_expression::types::NumberScalar;
use databend_common_expression::Scalar;
use databend_common_functions::BUILTIN_FUNCTIONS;
use ordered_float::OrderedFloat;

use crate::optimizer::rule::constant::check_float_range;
@@ -111,7 +113,7 @@ impl<'a> InferFilterOptimizer<'a> {
self.add_expr_predicate(&func.arguments[0], Predicate {
op,
constant,
});
})?;
} else {
remaining_predicates.push(predicate);
}
@@ -127,7 +129,7 @@ impl<'a> InferFilterOptimizer<'a> {
self.add_expr_predicate(&func.arguments[1], Predicate {
op: op.reverse(),
constant,
});
})?;
} else {
remaining_predicates.push(predicate);
}
@@ -145,7 +147,7 @@ impl<'a> InferFilterOptimizer<'a> {
let mut new_predicates = vec![];
if !self.is_falsy {
// Derive new predicates from existing predicates, `derive_predicates` may change is_falsy to true.
new_predicates = self.derive_predicates();
new_predicates = self.derive_predicates()?;
}

if self.is_falsy {
@@ -188,22 +190,22 @@ impl<'a> InferFilterOptimizer<'a> {
};
}

fn add_expr_predicate(&mut self, expr: &ScalarExpr, new_predicate: Predicate) {
fn add_expr_predicate(&mut self, expr: &ScalarExpr, new_predicate: Predicate) -> Result<()> {
match self.expr_index.get(expr) {
Some(index) => {
let predicates = &mut self.expr_predicates[*index];
for predicate in predicates.iter_mut() {
match Self::merge_predicate(predicate, &new_predicate) {
match Self::merge_predicate(predicate.clone(), new_predicate.clone())? {
MergeResult::None => {
self.is_falsy = true;
return;
return Ok(());
}
MergeResult::Left => {
return;
return Ok(());
}
MergeResult::Right => {
*predicate = new_predicate;
return;
return Ok(());
}
MergeResult::All => (),
}
@@ -214,10 +216,32 @@ impl<'a> InferFilterOptimizer<'a> {
self.add_expr(expr, vec![new_predicate], vec![]);
}
};
Ok(())
}

fn merge_predicate(left: &Predicate, right: &Predicate) -> MergeResult {
match left.op {
fn merge_predicate(mut left: Predicate, mut right: Predicate) -> Result<MergeResult> {
let left_data_type = ScalarExpr::ConstantExpr(left.constant.clone()).data_type()?;
let right_data_type = ScalarExpr::ConstantExpr(right.constant.clone()).data_type()?;
if left_data_type != right_data_type {
let common_data_type = common_super_type(
left_data_type,
right_data_type,
&BUILTIN_FUNCTIONS.default_cast_rules,
);
if let Some(data_type) = common_data_type {
let (left_is_adjusted, left_constant) =
adjust_scalar(left.constant.value.clone(), data_type.clone());
let (right_is_adjusted, right_constant) =
adjust_scalar(right.constant.value.clone(), data_type.clone());
if left_is_adjusted && right_is_adjusted {
left.constant = left_constant;
right.constant = right_constant;
}
} else {
return Ok(MergeResult::All);
}
}
let merge_result = match left.op {
ComparisonOp::Equal => match right.op {
ComparisonOp::Equal => match left.constant == right.constant {
true => MergeResult::Left,
@@ -358,7 +382,8 @@ impl<'a> InferFilterOptimizer<'a> {
false => MergeResult::Right,
},
},
}
};
Ok(merge_result)
}

fn find(parent: &mut [usize], x: usize) -> usize {
@@ -376,7 +401,7 @@ impl<'a> InferFilterOptimizer<'a> {
}
}

fn derive_predicates(&mut self) -> Vec<ScalarExpr> {
fn derive_predicates(&mut self) -> Result<Vec<ScalarExpr>> {
let mut result = vec![];
let num_exprs = self.exprs.len();

@@ -407,7 +432,7 @@ impl<'a> InferFilterOptimizer<'a> {
let expr = self.exprs[parent_index].clone();
let predicates = self.expr_predicates[index].clone();
for predicate in predicates {
self.add_expr_predicate(&expr, predicate);
self.add_expr_predicate(&expr, predicate)?;
}
}
}
@@ -455,7 +480,7 @@ impl<'a> InferFilterOptimizer<'a> {
}
}

result
Ok(result)
}

fn derive_remaining_predicates(&self, predicates: Vec<ScalarExpr>) -> Vec<ScalarExpr> {
64 changes: 64 additions & 0 deletions (sqllogictest)
@@ -910,3 +910,67 @@ drop table if exists t2;

statement ok
drop table if exists t3;

# merge predicates with different data types.
statement ok
drop table if exists t1;

statement ok
drop table if exists t2;

statement ok
create table t1(id BIGINT NOT NULL);

statement ok
create table t2(id BIGINT UNSIGNED NULL);

statement ok
insert into t1 values(869550529);

statement ok
insert into t2 values(869550529);

query T
explain SELECT * FROM t1 inner JOIN t2 on t1.id = t2.id where t2.id = 869550529;
----
HashJoin
├── output columns: [t1.id (#0), t2.id (#1)]
├── join type: INNER
├── build keys: [CAST(t2.id (#1) AS Int64 NULL)]
├── probe keys: [CAST(t1.id (#0) AS Int64 NULL)]
├── filters: []
├── estimated rows: 1.00
├── Filter(Build)
│ ├── output columns: [t2.id (#1)]
│ ├── filters: [is_true(t2.id (#1) = 869550529)]
│ ├── estimated rows: 1.00
│ └── TableScan
│ ├── table: default.default.t2
│ ├── output columns: [id (#1)]
│ ├── read rows: 1
│ ├── read size: < 1 KiB
│ ├── partitions total: 1
│ ├── partitions scanned: 1
│ ├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 1 to 1>]
│ ├── push downs: [filters: [is_true(t2.id (#1) = 869550529)], limit: NONE]
│ └── estimated rows: 1.00
└── Filter(Probe)
├── output columns: [t1.id (#0)]
├── filters: [t1.id (#0) = 869550529]
├── estimated rows: 1.00
└── TableScan
├── table: default.default.t1
├── output columns: [id (#0)]
├── read rows: 1
├── read size: < 1 KiB
├── partitions total: 1
├── partitions scanned: 1
├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 1 to 1>]
├── push downs: [filters: [t1.id (#0) = 869550529], limit: NONE]
└── estimated rows: 1.00

statement ok
drop table if exists t1;

statement ok
drop table if exists t2;
56 changes: 56 additions & 0 deletions (sqllogictest)
@@ -729,3 +729,59 @@ drop table if exists t2;

statement ok
drop table if exists t3;

# merge predicates with different data types.
statement ok
drop table if exists t1;

statement ok
drop table if exists t2;

statement ok
create table t1(id BIGINT NOT NULL);

statement ok
create table t2(id BIGINT UNSIGNED NULL);

statement ok
insert into t1 values(869550529);

statement ok
insert into t2 values(869550529);

query T
explain SELECT * FROM t1 inner JOIN t2 on t1.id = t2.id where t2.id = 869550529;
----
HashJoin
├── output columns: [t1.id (#0), t2.id (#1)]
├── join type: INNER
├── build keys: [CAST(t2.id (#1) AS Int64 NULL)]
├── probe keys: [CAST(t1.id (#0) AS Int64 NULL)]
├── filters: []
├── estimated rows: 1.00
├── TableScan(Build)
│ ├── table: default.default.t2
│ ├── output columns: [id (#1)]
│ ├── read rows: 1
│ ├── read size: < 1 KiB
│ ├── partitions total: 1
│ ├── partitions scanned: 1
│ ├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 1 to 1>]
│ ├── push downs: [filters: [is_true(t2.id (#1) = 869550529)], limit: NONE]
│ └── estimated rows: 1.00
└── TableScan(Probe)
├── table: default.default.t1
├── output columns: [id (#0)]
├── read rows: 1
├── read size: < 1 KiB
├── partitions total: 1
├── partitions scanned: 1
├── pruning stats: [segments: <range pruning: 1 to 1>, blocks: <range pruning: 1 to 1, bloom pruning: 1 to 1>]
├── push downs: [filters: [t1.id (#0) = 869550529], limit: NONE]
└── estimated rows: 1.00

statement ok
drop table if exists t1;

statement ok
drop table if exists t2;
