diff --git a/src/query/sql/src/planner/optimizer/filter/infer_filter.rs b/src/query/sql/src/planner/optimizer/filter/infer_filter.rs index 8b575dff668c..0d6bcf048cef 100644 --- a/src/query/sql/src/planner/optimizer/filter/infer_filter.rs +++ b/src/query/sql/src/planner/optimizer/filter/infer_filter.rs @@ -16,10 +16,12 @@ use std::collections::HashMap; use std::collections::HashSet; use databend_common_exception::Result; +use databend_common_expression::type_check::common_super_type; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberScalar; use databend_common_expression::Scalar; +use databend_common_functions::BUILTIN_FUNCTIONS; use ordered_float::OrderedFloat; use crate::optimizer::rule::constant::check_float_range; @@ -111,7 +113,7 @@ impl<'a> InferFilterOptimizer<'a> { self.add_expr_predicate(&func.arguments[0], Predicate { op, constant, - }); + })?; } else { remaining_predicates.push(predicate); } @@ -127,7 +129,7 @@ impl<'a> InferFilterOptimizer<'a> { self.add_expr_predicate(&func.arguments[1], Predicate { op: op.reverse(), constant, - }); + })?; } else { remaining_predicates.push(predicate); } @@ -145,7 +147,7 @@ impl<'a> InferFilterOptimizer<'a> { let mut new_predicates = vec![]; if !self.is_falsy { // Derive new predicates from existing predicates, `derive_predicates` may change is_falsy to true. - new_predicates = self.derive_predicates(); + new_predicates = self.derive_predicates()?; } if self.is_falsy { @@ -188,22 +190,22 @@ impl<'a> InferFilterOptimizer<'a> { }; } - fn add_expr_predicate(&mut self, expr: &ScalarExpr, new_predicate: Predicate) { + fn add_expr_predicate(&mut self, expr: &ScalarExpr, new_predicate: Predicate) -> Result<()> { match self.expr_index.get(expr) { Some(index) => { let predicates = &mut self.expr_predicates[*index]; for predicate in predicates.iter_mut() { - match Self::merge_predicate(predicate, &new_predicate) { + match Self::merge_predicate(predicate.clone(), new_predicate.clone())? { MergeResult::None => { self.is_falsy = true; - return; + return Ok(()); } MergeResult::Left => { - return; + return Ok(()); } MergeResult::Right => { *predicate = new_predicate; - return; + return Ok(()); } MergeResult::All => (), } @@ -214,10 +216,32 @@ impl<'a> InferFilterOptimizer<'a> { self.add_expr(expr, vec![new_predicate], vec![]); } }; + Ok(()) } - fn merge_predicate(left: &Predicate, right: &Predicate) -> MergeResult { - match left.op { + fn merge_predicate(mut left: Predicate, mut right: Predicate) -> Result { + let left_data_type = ScalarExpr::ConstantExpr(left.constant.clone()).data_type()?; + let right_data_type = ScalarExpr::ConstantExpr(right.constant.clone()).data_type()?; + if left_data_type != right_data_type { + let common_data_type = common_super_type( + left_data_type, + right_data_type, + &BUILTIN_FUNCTIONS.default_cast_rules, + ); + if let Some(data_type) = common_data_type { + let (left_is_adjusted, left_constant) = + adjust_scalar(left.constant.value.clone(), data_type.clone()); + let (right_is_adjusted, right_constant) = + adjust_scalar(right.constant.value.clone(), data_type.clone()); + if left_is_adjusted && right_is_adjusted { + left.constant = left_constant; + right.constant = right_constant; + } + } else { + return Ok(MergeResult::All); + } + } + let merge_result = match left.op { ComparisonOp::Equal => match right.op { ComparisonOp::Equal => match left.constant == right.constant { true => MergeResult::Left, @@ -358,7 +382,8 @@ impl<'a> InferFilterOptimizer<'a> { false => MergeResult::Right, }, }, - } + }; + Ok(merge_result) } fn find(parent: &mut [usize], x: usize) -> usize { @@ -376,7 +401,7 @@ impl<'a> InferFilterOptimizer<'a> { } } - fn derive_predicates(&mut self) -> Vec { + fn derive_predicates(&mut self) -> Result> { let mut result = vec![]; let num_exprs = self.exprs.len(); @@ -407,7 +432,7 @@ impl<'a> InferFilterOptimizer<'a> { let expr = self.exprs[parent_index].clone(); let predicates = self.expr_predicates[index].clone(); for predicate in predicates { - self.add_expr_predicate(&expr, predicate); + self.add_expr_predicate(&expr, predicate)?; } } } @@ -455,7 +480,7 @@ impl<'a> InferFilterOptimizer<'a> { } } - result + Ok(result) } fn derive_remaining_predicates(&self, predicates: Vec) -> Vec { diff --git a/tests/sqllogictests/suites/mode/standalone/explain/infer_filter.test b/tests/sqllogictests/suites/mode/standalone/explain/infer_filter.test index 56104b7065c9..16cdfd364980 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain/infer_filter.test +++ b/tests/sqllogictests/suites/mode/standalone/explain/infer_filter.test @@ -910,3 +910,67 @@ drop table if exists t2; statement ok drop table if exists t3; + +# merge predicates with different data types. +statement ok +drop table if exists t1; + +statement ok +drop table if exists t2; + +statement ok +create table t1(id BIGINT NOT NULL); + +statement ok +create table t2(id BIGINT UNSIGNED NULL); + +statement ok +insert into t1 values(869550529); + +statement ok +insert into t2 values(869550529); + +query T +explain SELECT * FROM t1 inner JOIN t2 on t1.id = t2.id where t2.id = 869550529; +---- +HashJoin +├── output columns: [t1.id (#0), t2.id (#1)] +├── join type: INNER +├── build keys: [CAST(t2.id (#1) AS Int64 NULL)] +├── probe keys: [CAST(t1.id (#0) AS Int64 NULL)] +├── filters: [] +├── estimated rows: 1.00 +├── Filter(Build) +│ ├── output columns: [t2.id (#1)] +│ ├── filters: [is_true(t2.id (#1) = 869550529)] +│ ├── estimated rows: 1.00 +│ └── TableScan +│ ├── table: default.default.t2 +│ ├── output columns: [id (#1)] +│ ├── read rows: 1 +│ ├── read size: < 1 KiB +│ ├── partitions total: 1 +│ ├── partitions scanned: 1 +│ ├── pruning stats: [segments: , blocks: ] +│ ├── push downs: [filters: [is_true(t2.id (#1) = 869550529)], limit: NONE] +│ └── estimated rows: 1.00 +└── Filter(Probe) + ├── output columns: [t1.id (#0)] + ├── filters: [t1.id (#0) = 869550529] + ├── estimated rows: 1.00 + └── TableScan + ├── table: default.default.t1 + ├── output columns: [id (#0)] + ├── read rows: 1 + ├── read size: < 1 KiB + ├── partitions total: 1 + ├── partitions scanned: 1 + ├── pruning stats: [segments: , blocks: ] + ├── push downs: [filters: [t1.id (#0) = 869550529], limit: NONE] + └── estimated rows: 1.00 + +statement ok +drop table if exists t1; + +statement ok +drop table if exists t2; diff --git a/tests/sqllogictests/suites/mode/standalone/explain_native/infer_filter.test b/tests/sqllogictests/suites/mode/standalone/explain_native/infer_filter.test index 4d01a9ecd4ac..3ef10d3bffe6 100644 --- a/tests/sqllogictests/suites/mode/standalone/explain_native/infer_filter.test +++ b/tests/sqllogictests/suites/mode/standalone/explain_native/infer_filter.test @@ -729,3 +729,59 @@ drop table if exists t2; statement ok drop table if exists t3; + +# merge predicates with different data types. +statement ok +drop table if exists t1; + +statement ok +drop table if exists t2; + +statement ok +create table t1(id BIGINT NOT NULL); + +statement ok +create table t2(id BIGINT UNSIGNED NULL); + +statement ok +insert into t1 values(869550529); + +statement ok +insert into t2 values(869550529); + +query T +explain SELECT * FROM t1 inner JOIN t2 on t1.id = t2.id where t2.id = 869550529; +---- +HashJoin +├── output columns: [t1.id (#0), t2.id (#1)] +├── join type: INNER +├── build keys: [CAST(t2.id (#1) AS Int64 NULL)] +├── probe keys: [CAST(t1.id (#0) AS Int64 NULL)] +├── filters: [] +├── estimated rows: 1.00 +├── TableScan(Build) +│ ├── table: default.default.t2 +│ ├── output columns: [id (#1)] +│ ├── read rows: 1 +│ ├── read size: < 1 KiB +│ ├── partitions total: 1 +│ ├── partitions scanned: 1 +│ ├── pruning stats: [segments: , blocks: ] +│ ├── push downs: [filters: [is_true(t2.id (#1) = 869550529)], limit: NONE] +│ └── estimated rows: 1.00 +└── TableScan(Probe) + ├── table: default.default.t1 + ├── output columns: [id (#0)] + ├── read rows: 1 + ├── read size: < 1 KiB + ├── partitions total: 1 + ├── partitions scanned: 1 + ├── pruning stats: [segments: , blocks: ] + ├── push downs: [filters: [t1.id (#0) = 869550529], limit: NONE] + └── estimated rows: 1.00 + +statement ok +drop table if exists t1; + +statement ok +drop table if exists t2;