Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ArrayAgg Builtin in favor of UDF #11611

Merged
merged 3 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

Expand Down Expand Up @@ -1973,7 +1973,7 @@ async fn test_array_agg() -> Result<()> {

let expected = [
"+-------------------------------------+",
"| ARRAY_AGG(test.a) |",
"| array_agg(test.a) |",
"+-------------------------------------+",
"| [abcDEF, abc123, CBAdef, 123AbcDef] |",
"+-------------------------------------+",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
assert_eq!(
*actual[0].schema(),
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
"array_agg(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, true),
true
),])
Expand Down
16 changes: 2 additions & 14 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//! Aggregate function module contains all built-in aggregate functions definitions

use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, Volatility};

use arrow::datatypes::{DataType, Field};
use arrow::datatypes::DataType;
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use strum_macros::EnumIter;
Expand All @@ -37,8 +36,6 @@ pub enum AggregateFunction {
Min,
/// Maximum
Max,
/// Aggregation into an array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

ArrayAgg,
}

impl AggregateFunction {
Expand All @@ -47,7 +44,6 @@ impl AggregateFunction {
match self {
Min => "MIN",
Max => "MAX",
ArrayAgg => "ARRAY_AGG",
}
}
}
Expand All @@ -65,7 +61,6 @@ impl FromStr for AggregateFunction {
// general
"max" => AggregateFunction::Max,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return plan_err!("There is no built-in function named {name}");
}
Expand All @@ -80,7 +75,7 @@ impl AggregateFunction {
pub fn return_type(
&self,
input_expr_types: &[DataType],
input_expr_nullable: &[bool],
_input_expr_nullable: &[bool],
) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
Expand All @@ -105,11 +100,6 @@ impl AggregateFunction {
// The coerced_data_types is same with input_types.
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
input_expr_nullable[0],
)))),
}
}

Expand All @@ -118,7 +108,6 @@ impl AggregateFunction {
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(true),
}
}
}
Expand All @@ -128,7 +117,6 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
Expand Down
7 changes: 1 addition & 6 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
// unpack the dictionary to get the value
Expand Down Expand Up @@ -360,11 +359,7 @@ mod tests {

// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
];
let funs = vec![AggregateFunction::Min, AggregateFunction::Max];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
pub enum ReversedUDAF {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
/// The expression does not support reverse calculation
NotSupported,
/// The expression is different from the original expression
Reversed(Arc<AggregateUDF>),
Expand Down
7 changes: 2 additions & 5 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,12 @@ make_udaf_expr_and_func!(
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
signature: Signature,
alias: Vec<String>,
}

impl Default for ArrayAgg {
fn default() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
alias: vec!["array_agg".to_string()],
}
}
}
Expand All @@ -67,13 +65,12 @@ impl AggregateUDFImpl for ArrayAgg {
self
}

// TODO: change name to lowercase
fn name(&self) -> &str {
"ARRAY_AGG"
"array_agg"
}

fn aliases(&self) -> &[String] {
&self.alias
&[]
}

fn signature(&self) -> &Signature {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-array/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl ExprPlanner for FieldAccessPlanner {

fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def {
return udf.name() == "ARRAY_AGG";
return udf.name() == "array_agg";
}

false
Expand Down
5 changes: 3 additions & 2 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,9 @@ impl AggregateExpr for AggregateFunctionExpr {
})
.collect::<Vec<_>>();
let mut name = self.name().to_string();
// TODO: Generalize order-by clause rewrite
if reverse_udf.name() == "ARRAY_AGG" {
// If the function is changed, we need to reverse order_by clause as well
// i.e. First(a order by b asc null first) -> Last(a order by b desc null last)
if self.fun().name() == reverse_udf.name() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess name checking is enough for now.

Introducing `supports_rewrite_order_by` for `AggregateUDFImpl` might add additional complexity without benefit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to eventually move this to a method in https://github.com/apache/datafusion/pull/11611, though I agree this is good for now. Maybe we can file a ticket to track it.

Copy link
Contributor Author

@jayzhan211 jayzhan211 Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @alamb
I came up with a better idea: adding a `reverse_name` method (paired with `reverse_udf`) instead of `supports_rewrite_order_by`, which makes more sense to me. See #11629.

} else {
replace_order_by_clause(&mut name);
}
Expand Down
4 changes: 1 addition & 3 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::sync::Arc;

use arrow::datatypes::Schema;

use datafusion_common::{internal_err, Result};
use datafusion_common::Result;
use datafusion_expr::AggregateFunction;

use crate::expressions::{self};
Expand All @@ -56,7 +56,6 @@ pub fn create_aggregate_expr(
let data_type = input_phy_types[0].clone();
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
(AggregateFunction::ArrayAgg, _) => return internal_err!("not reachable"),
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Arc::clone(&input_phy_exprs[0]),
name,
Expand Down Expand Up @@ -123,7 +122,6 @@ mod tests {
result_agg_phy_exprs.field().unwrap()
);
}
_ => {}
};
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ enum AggregateFunction {
// AVG = 3;
// COUNT = 4;
// APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
// ARRAY_AGG = 6;
// VARIANCE = 7;
// VARIANCE_POP = 8;
// COVARIANCE = 9;
Expand Down
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 2 additions & 5 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
match agg_fun {
protobuf::AggregateFunction::Min => Self::Min,
protobuf::AggregateFunction::Max => Self::Max,
protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
match value {
AggregateFunction::Min => Self::Min,
AggregateFunction::Max => Self::Max,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
Expand Down Expand Up @@ -386,7 +385,6 @@ pub fn serialize_expr(
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let aggr_function = match fun {
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
};
Expand Down
16 changes: 8 additions & 8 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ query TT
explain select array_agg(c1 order by c2 desc, c3) from agg_order;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]]
01)Aggregate: groupBy=[[]], aggr=[[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]]
02)--TableScan: agg_order projection=[c1, c2, c3]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
01)AggregateExec: mode=Final, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
03)----AggregateExec: mode=Partial, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
04)------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST], preserve_partitioning=[true]
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true
Expand Down Expand Up @@ -231,8 +231,8 @@ explain with A as (
) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id;
----
logical_plan
01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1))
02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))]]
01)Projection: array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))
02)--Aggregate: groupBy=[[a.id]], aggr=[[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]]
03)----SubqueryAlias: a
04)------SubqueryAlias: a
05)--------Union
Expand All @@ -247,11 +247,11 @@ logical_plan
14)----------Projection: Int64(1) AS id, Int64(2) AS foo
15)------------EmptyRelation
physical_plan
01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))]
02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))]
01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))]
02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]
03)----CoalesceBatchesExec: target_batch_size=8192
04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5
05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
06)----------UnionExec
07)------------ProjectionExec: expr=[1 as id, 2 as foo]
08)--------------PlaceholderRowExec
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/binary_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,4 @@ Raphael R false false true true
NULL R NULL NULL NULL NULL

statement ok
drop table test;
drop table test;
Loading
Loading