Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(query): support subquery in pivot #16631

Merged
merged 5 commits into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/query/ast/src/ast/format/syntax/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ pub(crate) fn pretty_table(table: TableReference) -> RcDoc<'static> {
lateral,
subquery,
alias,
pivot,
unpivot,
} => (if lateral {
RcDoc::text("LATERAL")
} else {
Expand All @@ -379,6 +381,16 @@ pub(crate) fn pretty_table(table: TableReference) -> RcDoc<'static> {
RcDoc::text(format!(" AS {alias}"))
} else {
RcDoc::nil()
})
.append(if let Some(pivot) = pivot {
RcDoc::text(format!(" {pivot}"))
} else {
RcDoc::nil()
})
.append(if let Some(unpivot) = unpivot {
RcDoc::text(format!(" {unpivot}"))
} else {
RcDoc::nil()
}),
TableReference::TableFunction {
span: _,
Expand Down
31 changes: 29 additions & 2 deletions src/query/ast/src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,17 +533,30 @@ impl Display for TimeTravelPoint {
}
}

#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
pub enum PivotValues {
ColumnValues(Vec<Expr>),
Subquery(Box<Query>),
}

#[derive(Debug, Clone, PartialEq, Drive, DriveMut)]
pub struct Pivot {
pub aggregate: Expr,
pub value_column: Identifier,
pub values: Vec<Expr>,
pub values: PivotValues,
}

impl Display for Pivot {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "PIVOT({} FOR {} IN (", self.aggregate, self.value_column)?;
write_comma_separated_list(f, &self.values)?;
match &self.values {
PivotValues::ColumnValues(column_values) => {
write_comma_separated_list(f, column_values)?;
}
PivotValues::Subquery(subquery) => {
write!(f, "{}", subquery)?;
}
}
write!(f, "))")?;
Ok(())
}
Expand Down Expand Up @@ -740,6 +753,8 @@ pub enum TableReference {
lateral: bool,
subquery: Box<Query>,
alias: Option<TableAlias>,
pivot: Option<Box<Pivot>>,
unpivot: Option<Box<Unpivot>>,
},
Join {
span: Span,
Expand All @@ -757,13 +772,15 @@ impl TableReference {
pub fn pivot(&self) -> Option<&Pivot> {
match self {
TableReference::Table { pivot, .. } => pivot.as_ref().map(|b| b.as_ref()),
TableReference::Subquery { pivot, .. } => pivot.as_ref().map(|b| b.as_ref()),
_ => None,
}
}

pub fn unpivot(&self) -> Option<&Unpivot> {
match self {
TableReference::Table { unpivot, .. } => unpivot.as_ref().map(|b| b.as_ref()),
TableReference::Subquery { unpivot, .. } => unpivot.as_ref().map(|b| b.as_ref()),
_ => None,
}
}
Expand Down Expand Up @@ -862,6 +879,8 @@ impl Display for TableReference {
lateral,
subquery,
alias,
pivot,
unpivot,
} => {
if *lateral {
write!(f, "LATERAL ")?;
Expand All @@ -870,6 +889,14 @@ impl Display for TableReference {
if let Some(alias) = alias {
write!(f, " AS {alias}")?;
}

if let Some(pivot) = pivot {
write!(f, " {pivot}")?;
}

if let Some(unpivot) = unpivot {
write!(f, " {unpivot}")?;
}
}
TableReference::Join { span: _, join } => {
write!(f, "{}", join.left)?;
Expand Down
2 changes: 2 additions & 0 deletions src/query/ast/src/ast/statements/merge_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ impl MergeSource {
lateral: false,
subquery: query.clone(),
alias: Some(source_alias.clone()),
pivot: None,
unpivot: None,
},
Self::Table {
catalog,
Expand Down
69 changes: 45 additions & 24 deletions src/query/ast/src/parser/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,8 @@ pub enum TableReferenceElement {
lateral: bool,
subquery: Box<Query>,
alias: Option<TableAlias>,
pivot: Option<Box<Pivot>>,
unpivot: Option<Box<Unpivot>>,
},
// [NATURAL] [INNER|OUTER|CROSS|...] JOIN
Join {
Expand All @@ -736,28 +738,6 @@ pub enum TableReferenceElement {
}

pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceElement>> {
// PIVOT(expr FOR col IN (ident, ...))
let pivot = map(
rule! {
PIVOT ~ "(" ~ #expr ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(expr) ~ ")" ~ ")"
},
|(_pivot, _, aggregate, _for, value_column, _in, _, values, _, _)| Pivot {
aggregate,
value_column,
values,
},
);
// UNPIVOT(ident for ident IN (ident, ...))
let unpivot = map(
rule! {
UNPIVOT ~ "(" ~ #ident ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(ident) ~ ")" ~ ")"
},
|(_unpivot, _, value_column, _for, column_name, _in, _, names, _, _)| Unpivot {
value_column,
column_name,
names,
},
);
let aliased_table = map(
rule! {
#dot_separated_idents_1_to_3 ~ #temporal_clause? ~ #with_options? ~ #table_alias? ~ #pivot? ~ #unpivot? ~ SAMPLE? ~ (BLOCK ~ "(" ~ #expr ~ ")")? ~ (ROW ~ "(" ~ #expr ~ ROWS? ~ ")")?
Expand Down Expand Up @@ -825,12 +805,14 @@ pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceEleme
);
let subquery = map(
rule! {
LATERAL? ~ "(" ~ #query ~ ")" ~ #table_alias?
LATERAL? ~ "(" ~ #query ~ ")" ~ #table_alias? ~ #pivot? ~ #unpivot?
},
|(lateral, _, subquery, _, alias)| TableReferenceElement::Subquery {
|(lateral, _, subquery, _, alias, pivot, unpivot)| TableReferenceElement::Subquery {
lateral: lateral.is_some(),
subquery: Box::new(subquery),
alias,
pivot: pivot.map(Box::new),
unpivot: unpivot.map(Box::new),
},
);

Expand Down Expand Up @@ -869,6 +851,41 @@ pub fn table_reference_element(i: Input) -> IResult<WithSpan<TableReferenceEleme
Ok((rest, WithSpan { span, elem }))
}

// PIVOT(expr FOR col IN (ident, ... | subquery))
fn pivot(i: Input) -> IResult<Pivot> {
map(
rule! {
PIVOT ~ "(" ~ #expr ~ FOR ~ #ident ~ IN ~ "(" ~ #pivot_values ~ ")" ~ ")"
},
|(_pivot, _, aggregate, _for, value_column, _in, _, values, _, _)| Pivot {
aggregate,
value_column,
values,
},
)(i)
}

// UNPIVOT(ident for ident IN (ident, ...))
fn unpivot(i: Input) -> IResult<Unpivot> {
map(
rule! {
UNPIVOT ~ "(" ~ #ident ~ FOR ~ #ident ~ IN ~ "(" ~ #comma_separated_list1(ident) ~ ")" ~ ")"
},
|(_unpivot, _, value_column, _for, column_name, _in, _, names, _, _)| Unpivot {
value_column,
column_name,
names,
},
)(i)
}

fn pivot_values(i: Input) -> IResult<PivotValues> {
alt((
map(comma_separated_list1(expr), PivotValues::ColumnValues),
map(query, |q| PivotValues::Subquery(Box::new(q))),
))(i)
}

fn get_table_sample(
sample: Option<&Token>,
block_level_sample: Option<(&Token, &Token, Expr, &Token)>,
Expand Down Expand Up @@ -966,11 +983,15 @@ impl<'a, I: Iterator<Item = WithSpan<'a, TableReferenceElement>>> PrattParser<I>
lateral,
subquery,
alias,
pivot,
unpivot,
} => TableReference::Subquery {
span: transform_span(input.span.tokens),
lateral,
subquery,
alias,
pivot,
unpivot,
},
TableReferenceElement::Stage {
location,
Expand Down
4 changes: 4 additions & 0 deletions src/query/ast/tests/it/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,11 @@ fn test_query() {
r#"SELECT * FROM (((SELECT *) EXCEPT (SELECT *))) foo"#,
r#"SELECT * FROM (SELECT * FROM xyu ORDER BY x, y) AS xyu"#,
r#"select * from monthly_sales pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
r#"select * from (select * from monthly_sales) pivot(sum(amount) for month in ('JAN', 'FEB', 'MAR', 'APR')) order by empid"#,
r#"select * from monthly_sales pivot(sum(amount) for month in (select distinct month from monthly_sales)) order by empid"#,
r#"select * from (select * from monthly_sales) pivot(sum(amount) for month in ((select distinct month from monthly_sales))) order by empid"#,
r#"select * from monthly_sales_1 unpivot(sales for month in (jan, feb, mar, april)) order by empid"#,
r#"select * from (select * from monthly_sales_1) unpivot(sales for month in (jan, feb, mar, april)) order by empid"#,
r#"select * from range(1, 2)"#,
r#"select sum(a) over w from customer window w as (partition by a order by b)"#,
r#"select a, sum(a) over w, sum(a) over w1, sum(a) over w2 from t1 window w as (partition by a), w2 as (w1 rows current row), w1 as (w order by a) order by a"#,
Expand Down
Loading
Loading