Skip to content

Commit

Permalink
fix: window in scalar subquery returns wrong results (#16567)
Browse files Browse the repository at this point in the history
* fix: window in scalar subquery returns wrong results

* fix cluster
  • Loading branch information
xudong963 authored Oct 9, 2024
1 parent 5f6c413 commit 530de62
Show file tree
Hide file tree
Showing 12 changed files with 577 additions and 798 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ impl SubqueryRewriter {
Arc::new(left.clone()),
Arc::new(flatten_plan),
);
Ok((s_expr, UnnestResult::SingleJoin { output_index: None }))
Ok((s_expr, UnnestResult::SingleJoin))
}
SubqueryType::Exists | SubqueryType::NotExists => {
if is_conjunctive_predicate {
Expand Down
237 changes: 46 additions & 191 deletions src/query/sql/src/planner/optimizer/decorrelate/subquery_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use crate::optimizer::RelExpr;
use crate::optimizer::SExpr;
use crate::plans::Aggregate;
use crate::plans::AggregateFunction;
use crate::plans::AggregateMode;
use crate::plans::BoundColumnRef;
use crate::plans::CastExpr;
use crate::plans::ComparisonOp;
Expand Down Expand Up @@ -61,7 +60,7 @@ pub enum UnnestResult {
// Semi/Anti Join, Cross join for EXISTS
SimpleJoin { output_index: Option<IndexType> },
MarkJoin { marker_index: IndexType },
SingleJoin { output_index: Option<IndexType> },
SingleJoin,
}

pub struct FlattenInfo {
Expand Down Expand Up @@ -164,6 +163,17 @@ impl SubqueryRewriter {
Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(input)))
}

RelOperator::Sort(mut sort) => {
let mut input = self.rewrite(s_expr.child(0)?)?;
for item in sort.window_partition.iter_mut() {
let res = self.try_rewrite_subquery(&item.scalar, &input, false)?;
input = res.1;
item.scalar = res.0;
}

Ok(SExpr::create_unary(Arc::new(sort.into()), Arc::new(input)))
}

RelOperator::Join(_) | RelOperator::UnionAll(_) | RelOperator::MaterializedCte(_) => {
Ok(SExpr::create_binary(
Arc::new(s_expr.plan().clone()),
Expand All @@ -172,13 +182,12 @@ impl SubqueryRewriter {
))
}

RelOperator::Limit(_)
| RelOperator::Sort(_)
| RelOperator::Udf(_)
| RelOperator::AsyncFunction(_) => Ok(SExpr::create_unary(
Arc::new(s_expr.plan().clone()),
Arc::new(self.rewrite(s_expr.child(0)?)?),
)),
RelOperator::Limit(_) | RelOperator::Udf(_) | RelOperator::AsyncFunction(_) => {
Ok(SExpr::create_unary(
Arc::new(s_expr.plan().clone()),
Arc::new(self.rewrite(s_expr.child(0)?)?),
))
}

RelOperator::DummyTableScan(_)
| RelOperator::Scan(_)
Expand Down Expand Up @@ -294,20 +303,15 @@ impl SubqueryRewriter {
}
let (index, name) = if let UnnestResult::MarkJoin { marker_index } = result {
(marker_index, marker_index.to_string())
} else if let UnnestResult::SingleJoin { output_index } = result {
if let Some(output_idx) = output_index {
// uncorrelated scalar subquery
(output_idx, "_if_scalar_subquery".to_string())
} else {
let mut output_column = subquery.output_column;
if let Some(index) = self.derived_columns.get(&output_column.index) {
output_column.index = *index;
}
(
output_column.index,
format!("scalar_subquery_{:?}", output_column.index),
)
} else if let UnnestResult::SingleJoin = result {
let mut output_column = subquery.output_column;
if let Some(index) = self.derived_columns.get(&output_column.index) {
output_column.index = *index;
}
(
output_column.index,
format!("scalar_subquery_{:?}", output_column.index),
)
} else {
let index = subquery.output_column.index;
(index, format!("subquery_{}", index))
Expand Down Expand Up @@ -423,7 +427,26 @@ impl SubqueryRewriter {
is_conjunctive_predicate: bool,
) -> Result<(SExpr, UnnestResult)> {
match subquery.typ {
SubqueryType::Scalar => self.rewrite_uncorrelated_scalar_subquery(left, subquery),
SubqueryType::Scalar => {
let join_plan = Join {
non_equi_conditions: vec![],
join_type: JoinType::LeftSingle,
marker_index: None,
from_correlated_subquery: false,
equi_conditions: vec![],
need_hold_hash_table: false,
is_lateral: false,
single_to_inner: None,
build_side_cache_info: None,
}
.into();
let s_expr = SExpr::create_binary(
Arc::new(join_plan),
Arc::new(left.clone()),
Arc::new(*subquery.subquery.clone()),
);
Ok((s_expr, UnnestResult::SingleJoin))
}
SubqueryType::Exists | SubqueryType::NotExists => {
let mut subquery_expr = *subquery.subquery.clone();
// Wrap Limit to current subquery
Expand Down Expand Up @@ -617,174 +640,6 @@ impl SubqueryRewriter {
_ => unreachable!(),
}
}

/// Rewrite an uncorrelated scalar subquery into a cross join against a
/// pre-aggregated, single-row sub-plan.
///
/// The subquery side is wrapped as:
///   Limit(1) <- Aggregate(count(col), any(col)) <- subquery
/// and an `EvalScalar` on top computes
///   `if(count() = 0, NULL, any(col))`
/// so that an empty subquery result yields NULL instead of an empty set
/// (SQL scalar-subquery semantics). The final plan is
///   CrossJoin(left, EvalScalar(...)).
///
/// Returns the new `SExpr` plus `UnnestResult::SingleJoin` carrying the
/// output index of the `if(...)` expression so the caller can reference
/// the scalar value.
///
/// NOTE(review): assumes the subquery produces at most the rows counted by
/// `count(output_column)`; `any()` picks an arbitrary value when more than
/// one row survives — TODO confirm a cardinality check is enforced elsewhere.
fn rewrite_uncorrelated_scalar_subquery(
    &mut self,
    left: &SExpr,
    subquery: &SubqueryExpr,
) -> Result<(SExpr, UnnestResult)> {
    // Use a cross join, which gives the optimizer a chance to push filters
    // down under it.
    // E.g. for `SELECT * FROM c WHERE c_id=(SELECT max(c_id) FROM o WHERE ship='WA');`
    // we can push `c_id = max(c_id)` down to the cross join and then turn it
    // into an inner join.
    let join_plan = Join {
        equi_conditions: JoinEquiCondition::new_conditions(vec![], vec![], vec![]),
        non_equi_conditions: vec![],
        join_type: JoinType::Cross,
        marker_index: None,
        from_correlated_subquery: false,
        need_hold_hash_table: false,
        is_lateral: false,
        single_to_inner: None,
        build_side_cache_info: None,
    }
    .into();

    // For some cases an empty result set can occur; we should return NULL
    // instead of an empty set.
    // So wrap the output in: `if(count() = 0, NULL, any(subquery.output_column))`.
    let count_func = ScalarExpr::AggregateFunction(AggregateFunction {
        span: subquery.span,
        func_name: "count".to_string(),
        distinct: false,
        params: vec![],
        args: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
            span: None,
            column: subquery.output_column.clone(),
        })],
        return_type: Box::new(DataType::Number(NumberDataType::UInt64)),
        display_name: "count".to_string(),
    });
    let any_func = ScalarExpr::AggregateFunction(AggregateFunction {
        span: subquery.span,
        func_name: "any".to_string(),
        distinct: false,
        params: vec![],
        return_type: subquery.output_column.data_type.clone(),
        args: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
            span: None,
            column: subquery.output_column.clone(),
        })],
        display_name: "any".to_string(),
    });
    // Register `count_func` and `any_func` outputs as derived columns in the
    // metadata so later operators can reference them by index.
    let count_idx = self.metadata.write().add_derived_column(
        "_count_scalar_subquery".to_string(),
        DataType::Number(NumberDataType::UInt64),
        None,
    );
    let any_idx = self.metadata.write().add_derived_column(
        "_any_scalar_subquery".to_string(),
        *subquery.output_column.data_type.clone(),
        None,
    );
    // Aggregate operator: collapses the subquery to a single row holding
    // count(col) and any(col).
    let agg = SExpr::create_unary(
        Arc::new(
            Aggregate {
                mode: AggregateMode::Initial,
                group_items: vec![],
                aggregate_functions: vec![
                    ScalarItem {
                        scalar: count_func,
                        index: count_idx,
                    },
                    ScalarItem {
                        scalar: any_func,
                        index: any_idx,
                    },
                ],
                ..Default::default()
            }
            .into(),
        ),
        Arc::new(*subquery.subquery.clone()),
    );

    // Limit 1 above the global aggregate: the aggregate already yields one
    // row, so this only makes the single-row guarantee explicit for the
    // join build side.
    let limit = SExpr::create_unary(
        Arc::new(
            Limit {
                limit: Some(1),
                offset: 0,
                before_exchange: false,
            }
            .into(),
        ),
        Arc::new(agg),
    );

    // Wrap expression: build column references to the two aggregates so the
    // `if` expression below can consume them.
    let count_col_ref = ScalarExpr::BoundColumnRef(BoundColumnRef {
        span: None,
        column: ColumnBindingBuilder::new(
            "_count_scalar_subquery".to_string(),
            count_idx,
            Box::new(DataType::Number(NumberDataType::UInt64)),
            Visibility::Visible,
        )
        .build(),
    });
    let any_col_ref = ScalarExpr::BoundColumnRef(BoundColumnRef {
        span: None,
        column: ColumnBindingBuilder::new(
            "_any_scalar_subquery".to_string(),
            any_idx,
            subquery.output_column.data_type.clone(),
            Visibility::Visible,
        )
        .build(),
    });
    // `count(col) = 0` — true iff the subquery produced no rows.
    let eq_func = ScalarExpr::FunctionCall(FunctionCall {
        span: None,
        func_name: "eq".to_string(),
        params: vec![],
        arguments: vec![
            count_col_ref,
            ScalarExpr::ConstantExpr(ConstantExpr {
                span: None,
                value: Scalar::Number(NumberScalar::UInt8(0)),
            }),
        ],
    });
    // If function: `if(count = 0, NULL, any(col))` — the scalar subquery's
    // final value.
    let if_func = ScalarExpr::FunctionCall(FunctionCall {
        span: None,
        func_name: "if".to_string(),
        params: vec![],
        arguments: vec![
            eq_func,
            ScalarExpr::ConstantExpr(ConstantExpr {
                span: None,
                value: Scalar::Null,
            }),
            any_col_ref,
        ],
    });
    let if_func_idx = self.metadata.write().add_derived_column(
        "_if_scalar_subquery".to_string(),
        *subquery.output_column.data_type.clone(),
        None,
    );
    let scalar_expr = SExpr::create_unary(
        Arc::new(
            EvalScalar {
                items: vec![ScalarItem {
                    scalar: if_func,
                    index: if_func_idx,
                }],
            }
            .into(),
        ),
        Arc::new(limit),
    );

    // Cross join the outer plan (probe) with the single-row scalar plan.
    let s_expr = SExpr::create_binary(
        Arc::new(join_plan),
        Arc::new(left.clone()),
        Arc::new(scalar_expr),
    );
    Ok((s_expr, UnnestResult::SingleJoin {
        output_index: Some(if_func_idx),
    }))
}
}

pub fn check_child_expr_in_subquery(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,17 @@ pub fn try_push_down_filter_join(s_expr: &SExpr, metadata: MetadataRef) -> Resul
}
JoinPredicate::Other(_) => original_predicates.push(predicate),
JoinPredicate::Both { is_equal_op, .. } => {
if matches!(join.join_type, JoinType::Inner | JoinType::Cross) {
if matches!(join.join_type, JoinType::Inner | JoinType::Cross)
|| join.single_to_inner.is_some()
{
if is_equal_op {
push_down_predicates.push(predicate);
} else {
non_equi_predicates.push(predicate);
}
join.join_type = JoinType::Inner;
if join.join_type == JoinType::Cross {
join.join_type = JoinType::Inner;
}
} else {
original_predicates.push(predicate);
}
Expand Down
4 changes: 4 additions & 0 deletions src/query/sql/src/planner/semantic/distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,16 @@ impl DistinctToGroupBy {
distinct,
name,
args,
window,
..
},
},
alias,
} = &select_list[0]
{
if window.is_some() {
return;
}
let sub_query_name = "_distinct_group_by_subquery";
if ((name.name.to_ascii_lowercase() == "count" && *distinct)
|| name.name.to_ascii_lowercase() == "count_distinct")
Expand Down
Loading

0 comments on commit 530de62

Please sign in to comment.