diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 398f59e35d10..818c36c23a57 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1658,10 +1658,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b539544d8372..a9bca1830dfa 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -92,6 +92,7 @@ use datafusion_expr::{ DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1854,6 +1855,33 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::AggregateFunction::ArrayAgg, + ) if !distinct && order_by.is_none() => { + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( + &array_agg_udaf(), + &physical_args, + args, + &sort_exprs, + &ordering_reqs, + physical_input_schema, + name, + ignore_nulls, + *distinct, + )?; + (agg_expr, filter, physical_sort_exprs) + } AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fa364c5f2a65..2d8736b9c47d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a87412ee6356..9feff05dcb32 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 169436145aae..7198f17a9df9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -98,6 +98,9 @@ pub struct StateFieldsArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, + /// If the input type is nullable. + pub input_nullable: bool, + /// The return type of the aggregate function. pub return_type: &'a DataType, diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs similarity index 55% rename from datafusion/physical-expr/src/aggregate/array_agg.rs rename to datafusion/functions-aggregate/src/array_agg.rs index a23ba07de44a..a0cedf5817ff 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -17,102 +17,118 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +use arrow::array::{Array, ArrayRef}; +use arrow::datatypes::DataType; +use arrow_schema::Field; + use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::Expr; +use datafusion_expr::{Accumulator, Signature, Volatility}; use std::sync::Arc; -/// ARRAY_AGG aggregate expression +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "input values, including nulls, concatenated into an array", + array_agg_udaf +); + #[derive(Debug)] +/// ARRAY_AGG aggregate expression pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, + signature: Signature, + alias: Vec, } -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - nullable: bool, - ) -> Self { +impl Default for ArrayAgg { + fn default() -> Self { Self { - name: name.into(), - input_data_type: data_type, - expr, - nullable, + signature: Signature::any(1, Volatility::Immutable), + alias: vec!["array_agg".to_string()], } } } -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) + fn name(&self) -> &str { + "ARRAY_AGG" } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) + fn aliases(&self) -> &[String] { + &self.alias } - fn state_fields(&self) -> Result> { + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn state_fields( + &self, + args: datafusion_expr::function::StateFieldsArgs, + ) -> Result> { Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + args.input_nullable, )]) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) } - fn name(&self) -> &str { - &self.name + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical } -} -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn simplify( + &self, + ) -> Option { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + if aggregate_function.order_by.is_some() || aggregate_function.distinct { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + } else { + Ok(Expr::AggregateFunction(aggregate_function)) + } + }; + + Some(Box::new(simplify)) } } #[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { +pub struct ArrayAggAccumulator { values: Vec, datatype: DataType, } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index dd38e3487264..066ab77eadcf 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,6 +440,7 @@ impl AggregateUDFImpl for LastValue { let StateFieldsArgs { name, input_type, + input_nullable: _, return_type: _, ordering_fields, is_distinct: _, diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 260d6dab31b9..e0556d64f768 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod approx_distinct; +pub mod array_agg; pub mod count; pub mod covariance; pub mod first_last; @@ -86,6 +87,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -117,6 +119,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -177,7 +180,8 @@ mod tests { for func in all_default_aggregate_functions() { // TODO: remove this // These functions are in intermidiate migration state, skip them - if func.name().to_lowercase() == "count" { + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" || name_lower_case == "array_agg" { continue; } assert!( diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 432267e045b2..1d1f2c44d5c9 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -87,6 +87,7 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), + input_nullable: input_phy_exprs[0].nullable(schema)?, })) } @@ -248,6 +249,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, input_type: DataType, + input_nullable: bool, } impl AggregateFunctionExpr { @@ -276,6 +278,7 @@ impl AggregateExpr for AggregateFunctionExpr { let args = StateFieldsArgs { name: &self.name, input_type: &self.input_type, + input_nullable: self.input_nullable, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -285,7 +288,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.input_nullable, + )) } fn create_accumulator(&self) -> Result> { diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 53cfcfb033a1..dffddedbf5fd 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; @@ -71,7 +71,9 @@ pub fn create_aggregate_expr( let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + return internal_err!( + "ArrayAgg without ordering should be handled as UDAF" + ); } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, @@ -155,7 +157,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Avg, DistinctArrayAgg, Max, Min}; use super::*; #[test] @@ -176,25 +178,6 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", &input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -212,7 +195,7 @@ mod tests { Field::new("item", data_type.clone(), true), true, ), - result_agg_phy_exprs.field().unwrap() + result_distinct.field().unwrap() ); } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f64c5b1fb260..a4aaa7d03951 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,7 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; pub(crate) mod average; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0020aa5f55b2..bc5056e3d7ff 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,7 +35,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::Avg; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a9d3736dee08..b963113a82dc 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,10 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, + Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, + NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, + RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,8 +240,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { distinct = true; protobuf::AggregateFunction::ArrayAgg diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b3966c3f0204..95e75f825cfd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ - bit_and, bit_or, bit_xor, bool_and, bool_or, + array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -675,6 +675,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), + array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 8b64ccfb52cb..4db19b8381b0 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -596,10 +596,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + match agg_func.func_def { + datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( AggregateFunction::ArrayAgg, - ) + ) => true, + datafusion_expr::expr::AggregateFunctionDefinition::UDF(ref udf) => { + udf.name() == "ARRAY_AGG" + } + _ => false, + } } match expr { Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => {