diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b5c58eff577c..5e08b47d5257 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1590,10 +1590,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fa364c5f2a65..2d8736b9c47d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..3acf5f814984 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - false + true ),]) ); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 967ccc0b0866..67e11bda9a0d 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -17,13 +17,12 @@ //! Aggregate function module contains all built-in aggregate functions definitions -use std::sync::Arc; use std::{fmt, str::FromStr}; use crate::utils; use crate::{type_coercion::aggregates::*, Signature, Volatility}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use strum_macros::EnumIter; @@ -39,10 +38,6 @@ pub enum AggregateFunction { Max, /// Average Avg, - /// Aggregation into an array - ArrayAgg, - /// N'th value in a group according to some ordering - NthValue, /// Correlation Correlation, /// Grouping @@ -56,8 +51,6 @@ impl AggregateFunction { Min => "MIN", Max => "MAX", Avg => "AVG", - ArrayAgg => "ARRAY_AGG", - NthValue => "NTH_VALUE", Correlation => "CORR", Grouping => "GROUPING", } @@ -79,8 +72,6 @@ impl FromStr for AggregateFunction { "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, - "array_agg" => AggregateFunction::ArrayAgg, - "nth_value" => AggregateFunction::NthValue, // statistical "corr" => AggregateFunction::Correlation, // other @@ -124,13 +115,7 @@ impl AggregateFunction { correlation_return_type(&coerced_data_types[0]) } AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), - AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( - "item", - coerced_data_types[0].clone(), - true, - )))), AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } } @@ -153,9 +138,7 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { - Signature::any(1, Volatility::Immutable) - } + AggregateFunction::Grouping => Signature::any(1, Volatility::Immutable), AggregateFunction::Min | AggregateFunction::Max => { let valid = STRINGS .iter() @@ -171,7 +154,6 @@ impl AggregateFunction { AggregateFunction::Avg => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a87412ee6356..9feff05dcb32 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 428fc99070d2..920dfc51212c 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -95,7 +95,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type // unpack the dictionary to get the value @@ -131,7 +130,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), } } @@ -383,11 +381,7 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types - let funs = vec![ - AggregateFunction::ArrayAgg, - AggregateFunction::Min, - AggregateFunction::Max, - ]; + let funs = vec![AggregateFunction::Min, AggregateFunction::Max]; let input_types = vec![ vec![DataType::Int32], vec![DataType::Decimal128(10, 2)], diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d5..caee13095c9a 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,6 +40,7 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } @@ -48,3 +49,6 @@ datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } + +[dev-dependencies] +arrow-buffer = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs new file mode 100644 index 000000000000..91d2970d2024 --- /dev/null +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under on +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use arrow_array::Array; +use arrow_schema::Field; +use arrow_schema::Fields; +use datafusion_common::cast::as_list_array; +use datafusion_common::not_impl_err; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDF; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::{Accumulator, Signature, Volatility}; +use datafusion_physical_expr_common::sort_expr::limited_convert_logical_sort_exprs_to_physical; +use std::sync::Arc; + +use crate::array_agg_distinct::DistinctArrayAggAccumulator; +use crate::array_agg_ordered::OrderSensitiveArrayAggAccumulator; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "Computes the nth value", + array_agg_udaf +); + +#[derive(Debug)] +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, + alias: Vec, + reverse: bool, +} + +impl Default for ArrayAgg { + fn default() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + alias: vec!["array_agg".to_string()], + reverse: false, + } + } +} + +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "ARRAY_AGG" + } + + fn aliases(&self) -> &[String] { + &self.alias + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn state_fields( + &self, + args: datafusion_expr::function::StateFieldsArgs, + ) -> Result> { + let mut fields = vec![Field::new_list( + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]; + if !args.ordering_fields.is_empty() { + fields.push(Field::new_list( + format_state_name(args.name, "array_agg_orderings"), + Field::new( + "item", + DataType::Struct(Fields::from(args.ordering_fields.to_vec())), + true, + ), + true, + )); + } + Ok(fields) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if !acc_args.sort_exprs.is_empty() && acc_args.is_distinct { + not_impl_err!("ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available") + } else if !acc_args.sort_exprs.is_empty() { + let ordering_req = limited_convert_logical_sort_exprs_to_physical( + acc_args.sort_exprs, + acc_args.schema, + )?; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + Ok(Box::new(OrderSensitiveArrayAggAccumulator::try_new( + acc_args.input_type, + &ordering_dtypes, + ordering_req, + self.reverse, + )?)) + } else if acc_args.is_distinct { + Ok(Box::new(DistinctArrayAggAccumulator::try_new( + acc_args.input_type, + )?)) + } else { + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) + } + } + + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::HardRequirement + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(Arc::new(AggregateUDF::from(Self { + signature: self.signature.clone(), + alias: self.alias.clone(), + reverse: !self.reverse, + }))) + } +} + +#[derive(Debug)] +pub(crate) struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} + +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + assert!(values.len() == 1, "array_agg can only take 1 param!"); + let val = values[0].clone(); + self.values.push(val); + Ok(()) + } + + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert!(states.len() == 1, "array_agg states must be singleton!"); + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::column::Column; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + #[test] + fn test_array_agg_expr() -> Result<()> { + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal128(10, 2), + DataType::Utf8, + ]; + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + false, + )?; + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + + let result_distinct = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + true, + )?; + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + } + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/functions-aggregate/src/array_agg_distinct.rs similarity index 79% rename from datafusion/physical-expr/src/aggregate/array_agg_distinct.rs rename to datafusion/functions-aggregate/src/array_agg_distinct.rs index 244a44acdcb5..e1f59c0283a3 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/functions-aggregate/src/array_agg_distinct.rs @@ -17,106 +17,18 @@ //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` -use std::any::Any; use std::collections::HashSet; use std::fmt::Debug; -use std::sync::Arc; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use arrow_array::cast::AsArray; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. -#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - nullable: bool, - ) -> Self { - let name = name.into(); - Self { - name, - input_data_type, - expr, - nullable, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - #[derive(Debug)] -struct DistinctArrayAggAccumulator { +pub(crate) struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, } @@ -177,9 +89,10 @@ impl Accumulator for DistinctArrayAggAccumulator { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; + use crate::array_agg::array_agg_udaf; use arrow::array::Int32Array; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; @@ -187,7 +100,10 @@ mod tests { use arrow_array::Array; use arrow_array::ListArray; use arrow_buffer::OffsetBuffer; + use arrow_schema::Field; use datafusion_common::internal_err; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::column::col; // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. fn compare_list_contents( @@ -241,13 +157,31 @@ mod tests { let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, + let agg = create_aggregate_expr( + &array_agg_udaf(), + &vec![col("a", &schema)?], + &[], + &[], + &[], + &schema, + "array_agg_distinct", + false, true, - )); - let actual = aggregate(&batch, agg)?; + )?; + + let actual = { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| { + e.evaluate(&batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate()? + }; compare_list_contents(expected, actual) } @@ -258,12 +192,18 @@ mod tests { datatype: DataType, ) -> Result<()> { let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, + + let agg = create_aggregate_expr( + &array_agg_udaf(), + &vec![col("a", &schema)?], + &[], + &[], + &[], + &schema, + "array_agg_distinct", + false, true, - )); + )?; let mut accum1 = agg.create_accumulator()?; let mut accum2 = agg.create_accumulator()?; diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/functions-aggregate/src/array_agg_ordered.rs similarity index 82% rename from datafusion/physical-expr/src/aggregate/array_agg_ordered.rs rename to datafusion/functions-aggregate/src/array_agg_ordered.rs index 837a9d551153..b90d3af84cba 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/functions-aggregate/src/array_agg_ordered.rs @@ -18,152 +18,20 @@ //! Defines physical expressions which specify ordering requirement //! that can evaluated at runtime during query execution -use std::any::Any; use std::cmp::Ordering; use std::collections::{BinaryHeap, VecDeque}; use std::fmt::Debug; use std::sync::Arc; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::format_state_name; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, Array, ArrayRef, StructArray}; use arrow_schema::{Fields, SortOptions}; use datafusion_common::utils::{array_into_list_array, compare_rows, get_row_at_idx}; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; - -/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi -/// partition setting, partial aggregations are computed for every partition, -/// and then their results are merged. -#[derive(Debug)] -pub struct OrderSensitiveArrayAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have `NULL`s - nullable: bool, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, - /// Whether the aggregation is running in reverse - reverse: bool, -} - -impl OrderSensitiveArrayAgg { - /// Create a new `OrderSensitiveArrayAgg` aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - nullable: bool, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { - Self { - name: name.into(), - input_data_type, - expr, - nullable, - order_by_data_types, - ordering_req, - reverse: false, - } - } -} - -impl AggregateExpr for OrderSensitiveArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - OrderSensitiveArrayAggAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.reverse, - ) - .map(|acc| Box::new(acc) as _) - } - - fn state_fields(&self) -> Result> { - let mut fields = vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() - )]; - let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); - fields.push(Field::new_list( - format_state_name(&self.name, "array_agg_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, - )); - Ok(fields) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::HardRequirement - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), - nullable: self.nullable, - order_by_data_types: self.order_by_data_types.clone(), - // Reverse requirement: - ordering_req: reverse_order_bys(&self.ordering_req), - reverse: !self.reverse, - })) - } -} - -impl PartialEq for OrderSensitiveArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} +use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; #[derive(Debug)] pub(crate) struct OrderSensitiveArrayAggAccumulator { @@ -544,13 +412,13 @@ mod tests { use std::collections::VecDeque; use std::sync::Arc; - use crate::aggregate::array_agg_ordered::merge_ordered_arrays; - use arrow_array::{Array, ArrayRef, Int64Array}; use arrow_schema::SortOptions; use datafusion_common::utils::get_row_at_idx; use datafusion_common::{Result, ScalarValue}; + use crate::array_agg_ordered::merge_ordered_arrays; + #[test] fn test_merge_asc() -> Result<()> { let lhs_arrays: Vec = vec![ diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 260d6dab31b9..7c6cecc4ea53 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,11 +56,15 @@ pub mod macros; pub mod approx_distinct; +pub mod array_agg; +pub mod array_agg_distinct; +pub mod array_agg_ordered; pub mod count; pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod nth_value; pub mod regr; pub mod stddev; pub mod sum; @@ -86,6 +90,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -98,6 +103,7 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::nth_value::nth_value; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -117,6 +123,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -124,6 +131,7 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), median::median_udaf(), count::count_udaf(), + nth_value::nth_value_udaf(), regr::regr_slope_udaf(), regr::regr_intercept_udaf(), regr::regr_count_udaf(), @@ -177,7 +185,11 @@ mod tests { for func in all_default_aggregate_functions() { // TODO: remove this // These functions are in intermidiate migration state, skip them - if func.name().to_lowercase() == "count" { + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" + || name_lower_case == "array_agg" + || name_lower_case == "nth_value" + { continue; } assert!( diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs similarity index 78% rename from datafusion/physical-expr/src/aggregate/nth_value.rs rename to datafusion/functions-aggregate/src/nth_value.rs index ee7426a897b3..6e8b3fb8b770 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -22,144 +22,136 @@ use std::any::Any; use std::collections::VecDeque; use std::sync::Arc; -use crate::aggregate::array_agg_ordered::merge_ordered_arrays; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::{format_state_name, Literal}; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, ArrayRef, StructArray}; use arrow_schema::{DataType, Field, Fields}; + use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::Accumulator; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::{ + Accumulator, AggregateUDF, AggregateUDFImpl, Expr, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{ + limited_convert_logical_sort_exprs_to_physical, LexOrdering, PhysicalSortExpr, +}; + +use crate::array_agg_ordered::merge_ordered_arrays; + +make_udaf_expr_and_func!( + NthValueAgg, + nth_value, + expression, + "Computes the nth value", + nth_value_udaf +); /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. #[derive(Debug)] pub struct NthValueAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// The `N` value. - n: i64, - /// If the input expression can have `NULL`s - nullable: bool, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, + signature: Signature, + alias: Vec, + reverse: bool, } -impl NthValueAgg { - /// Create a new `NthValueAgg` aggregate function - pub fn new( - expr: Arc, - n: i64, - name: impl Into, - input_data_type: DataType, - nullable: bool, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { +impl Default for NthValueAgg { + fn default() -> Self { Self { - name: name.into(), - input_data_type, - expr, - n, - nullable, - order_by_data_types, - ordering_req, + signature: Signature::any(2, Volatility::Immutable), + alias: vec!["nth_value".to_string()], + reverse: false, } } } -impl AggregateExpr for NthValueAgg { +impl AggregateUDFImpl for NthValueAgg { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + fn name(&self) -> &str { + "NTH_VALUE" } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - self.n, - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - )?)) + fn aliases(&self) -> &[String] { + &self.alias + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - fn state_fields(&self) -> Result> { + fn state_fields( + &self, + args: datafusion_expr::function::StateFieldsArgs, + ) -> Result> { let mut fields = vec![Field::new_list( - format_state_name(&self.name, "nth_value"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + format_state_name(args.name, "nth_value"), + Field::new("item", args.input_type.clone(), true), + true, )]; - if !self.ordering_req.is_empty() { - let orderings = - ordering_fields(&self.ordering_req, &self.order_by_data_types); + if !args.ordering_fields.is_empty() { fields.push(Field::new_list( - format_state_name(&self.name, "nth_value_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, + format_state_name(args.name, "nth_value_orderings"), + Field::new( + "item", + DataType::Struct(Fields::from(args.ordering_fields.to_vec())), + true, + ), + true, )); } Ok(fields) } - fn expressions(&self) -> Vec> { - let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _; - vec![self.expr.clone(), n] - } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let n = match acc_args.input_exprs[1] { + Expr::Literal(ScalarValue::Int64(Some(n))) => { + if self.reverse { + -n + } else { + n + } + } + _ => return exec_err!("Second argument of NTH_VALUE needs to be a literal"), + }; - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } + let ordering_req = limited_convert_logical_sort_exprs_to_physical( + acc_args.sort_exprs, + acc_args.schema, + )?; - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::HardRequirement - } + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; - fn name(&self) -> &str { - &self.name + Ok(Box::new(NthValueAccumulator::try_new( + n, + acc_args.input_type, + &ordering_dtypes, + ordering_req, + )?)) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), - // index should be from the opposite side - n: -self.n, - nullable: self.nullable, - order_by_data_types: self.order_by_data_types.clone(), - // reverse requirement - ordering_req: reverse_order_bys(&self.ordering_req), - }) as _) + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::HardRequirement } -} -impl PartialEq for NthValueAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(Arc::new(AggregateUDF::from(Self { + signature: self.signature.clone(), + alias: self.alias.clone(), + reverse: !self.reverse, + }))) } } diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index a23ba07de44a..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,188 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - nullable: bool, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - nullable, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); - self.values.push(val); - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype); - return Ok(ScalarValue::List(arr)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 53cfcfb033a1..4c5d9f99526c 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,11 +30,11 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; -use crate::expressions::{self, Literal}; +use crate::expressions; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; /// Create a physical aggregation expression. @@ -43,7 +43,7 @@ pub fn create_aggregate_expr( fun: &AggregateFunction, distinct: bool, input_phy_exprs: &[Arc], - ordering_req: &[PhysicalSortExpr], + _ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, _ignore_nulls: bool, @@ -55,10 +55,6 @@ pub fn create_aggregate_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; let data_type = input_phy_types[0].clone(); - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(input_schema)) - .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( @@ -66,38 +62,6 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::ArrayAgg, false) => { - let expr = input_phy_exprs[0].clone(); - let nullable = expr.nullable(input_schema)?; - - if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) - } else { - Arc::new(expressions::OrderSensitiveArrayAgg::new( - expr, - name, - data_type, - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } - } - (AggregateFunction::ArrayAgg, true) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - let expr = input_phy_exprs[0].clone(); - let is_expr_nullable = expr.nullable(input_schema)?; - Arc::new(expressions::DistinctArrayAgg::new( - expr, - name, - data_type, - is_expr_nullable, - )) - } (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( input_phy_exprs[0].clone(), name, @@ -125,26 +89,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::NthValue, _) => { - let expr = &input_phy_exprs[0]; - let Some(n) = input_phy_exprs[1] - .as_any() - .downcast_ref::() - .map(|literal| literal.value()) - else { - return exec_err!("Second argument of NTH_VALUE needs to be a literal"); - }; - let nullable = expr.nullable(input_schema)?; - Arc::new(expressions::NthValueAgg::new( - expr.clone(), - n.clone().try_into()?, - name, - input_phy_types[0].clone(), - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } }) } @@ -155,70 +99,9 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Avg, Max, Min}; use super::*; - #[test] - fn test_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ArrayAgg]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - - let result_distinct = create_physical_agg_expr_for_test( - &fun, - true, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - } - } - Ok(()) - } #[test] fn test_min_max_expr() -> Result<()> { diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f64c5b1fb260..b5bb2cdaddc5 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,14 +17,10 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg; -pub(crate) mod array_agg_distinct; -pub(crate) mod array_agg_ordered; pub(crate) mod average; pub(crate) mod correlation; pub(crate) mod covariance; pub(crate) mod grouping; -pub(crate) mod nth_value; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0020aa5f55b2..e788809c2df8 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,16 +35,12 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; -pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; -pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::Avg; pub use crate::aggregate::average::AvgAccumulator; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; -pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b7d8d60f4f35..5a7f276b4b9c 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,11 +1194,11 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_expr::Expr; + use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::median::median_udaf; - use datafusion_physical_expr::expressions::{ - lit, FirstValue, LastValue, OrderSensitiveArrayAgg, - }; + use datafusion_physical_expr::expressions::{lit, FirstValue, LastValue}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_expr_common::aggregate::create_aggregate_expr; @@ -2145,42 +2145,44 @@ mod tests { let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema); + let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Columns a and b are equal. eq_properties.add_equal_conditions(col_a, col_b)?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively - let order_by_exprs = vec![ + let order_by_names = vec![ None, - Some(vec![PhysicalSortExpr { - expr: col_a.clone(), - options: options1, - }]), - Some(vec![ - PhysicalSortExpr { - expr: col_a.clone(), - options: options1, - }, - PhysicalSortExpr { - expr: col_b.clone(), - options: options1, - }, - PhysicalSortExpr { - expr: col_c.clone(), - options: options1, - }, - ]), - Some(vec![ - PhysicalSortExpr { - expr: col_a.clone(), - options: options1, - }, - PhysicalSortExpr { - expr: col_b.clone(), - options: options1, - }, - ]), + Some(vec!["a"]), + Some(vec!["a", "b", "c"]), + Some(vec!["a", "b"]), ]; + let sort_exprs = order_by_names.clone().into_iter().map(|maybe_names| { + maybe_names.map(|names| { + names + .iter() + .map(|name| { + Expr::Sort(Sort { + expr: Box::new(Expr::Column( + datafusion_common::Column::new_unqualified(*name), + )), + asc: true, + nulls_first: false, + }) + }) + .collect::>() + }) + }); + let order_by_exprs = order_by_names.into_iter().map(|maybe_names| { + maybe_names.map(|names| { + names + .into_iter() + .map(|name| PhysicalSortExpr { + expr: col(name, &test_schema).unwrap(), + options: options1, + }) + .collect::>() + }) + }); let common_requirement = vec![ PhysicalSortExpr { expr: col_a.clone(), @@ -2192,16 +2194,20 @@ mod tests { }, ]; let mut aggr_exprs = order_by_exprs - .into_iter() - .map(|order_by_expr| { - Arc::new(OrderSensitiveArrayAgg::new( - col_a.clone(), + .zip(sort_exprs) + .map(|(order_by_expr, sort_expr)| { + create_aggregate_expr( + &array_agg_udaf(), + &[col_a.clone()], + &[], + &sort_expr.unwrap_or_default(), + &order_by_expr.unwrap_or_default(), + &test_schema, "array_agg", - DataType::Int32, false, - vec![], - order_by_expr.unwrap_or_default(), - )) as _ + false, + ) + .unwrap() }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 50356d5b6052..d6c102b2d99d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -478,7 +478,7 @@ enum AggregateFunction { AVG = 3; // COUNT = 4; // APPROX_DISTINCT = 5; - ARRAY_AGG = 6; + // ARRAY_AGG = 6; // VARIANCE = 7; // VARIANCE_POP = 8; // COVARIANCE = 9; @@ -506,7 +506,7 @@ enum AggregateFunction { // REGR_SYY = 33; // REGR_SXY = 34; // STRING_AGG = 35; - NTH_VALUE_AGG = 36; + // NTH_VALUE_AGG = 36; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8cca0fe4a876..60ed7b851141 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,10 +535,8 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::Grouping => "GROUPING", - Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) } @@ -553,10 +551,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "ARRAY_AGG", "CORRELATION", "GROUPING", - "NTH_VALUE_AGG", ]; struct GeneratedVisitor; @@ -600,10 +596,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "GROUPING" => Ok(AggregateFunction::Grouping), - "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 56f14982923d..7da7fb1715b5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1932,7 +1932,7 @@ pub enum AggregateFunction { Avg = 3, /// COUNT = 4; /// APPROX_DISTINCT = 5; - ArrayAgg = 6, + /// ARRAY_AGG = 6; /// VARIANCE = 7; /// VARIANCE_POP = 8; /// COVARIANCE = 9; @@ -1943,7 +1943,7 @@ pub enum AggregateFunction { /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - Grouping = 17, + /// /// MEDIAN = 18; /// BIT_AND = 19; /// BIT_OR = 20; @@ -1960,7 +1960,8 @@ pub enum AggregateFunction { /// REGR_SYY = 33; /// REGR_SXY = 34; /// STRING_AGG = 35; - NthValueAgg = 36, + /// NTH_VALUE_AGG = 36; + Grouping = 17, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -1972,10 +1973,8 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", - AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::Grouping => "GROUPING", - AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -1984,10 +1983,8 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), - "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "GROUPING" => Some(Self::Grouping), - "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ba0e708218cf..951a7004bb8d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -140,10 +140,8 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::Avg => Self::Avg, - protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::NthValueAgg => Self::NthValue, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 08999effa4b1..0c2d4ffb66b0 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -111,10 +111,8 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::Avg => Self::Avg, - AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::Grouping => Self::Grouping, - AggregateFunction::NthValue => Self::NthValueAgg, } } } @@ -371,7 +369,6 @@ pub fn serialize_expr( }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, @@ -379,9 +376,6 @@ pub fn serialize_expr( protobuf::AggregateFunction::Correlation } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a9d3736dee08..1147f9d6ef72 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,9 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, Ntile, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -236,17 +235,10 @@ struct AggrFn { fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); - let mut distinct = false; + let distinct = false; let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Min } else if aggr_expr.downcast_ref::().is_some() { @@ -255,8 +247,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::NthValueAgg } else { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index eb3313239544..758c62ff19b6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -22,6 +22,7 @@ use std::vec; use arrow::csv::WriterBuilder; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion_functions_aggregate::nth_value::NthValueAgg; use prost::Message; use datafusion::arrow::array::ArrayRef; @@ -38,7 +39,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Max, NthValueAgg}; +use datafusion::physical_expr::expressions::Max; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -348,15 +349,20 @@ fn rountrip_aggregate() -> Result<()> { DataType::Float64, ))], // NTH_VALUE - vec![Arc::new(NthValueAgg::new( - col("b", &schema)?, - 1, - "NTH_VALUE(b, 1)".to_string(), - DataType::Int64, + vec![udaf::create_aggregate_expr( + &AggregateUDF::new_from_impl(NthValueAgg::default()), + &[ + cast(col("b", &schema)?, &schema, DataType::Utf8)?, + lit(ScalarValue::Int64(Some(1))), + ], + &[], + &[], + &[], + &schema, + "NTH_VALUE(b, 1)", false, - Vec::new(), - Vec::new(), - ))], + false, + )?], // STRING_AGG vec![udaf::create_aggregate_expr( &AggregateUDF::new_from_impl(StringAgg::new()), diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index c4ae3a8134a6..11bb12645859 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,6 +46,7 @@ arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate= { workspace = true } log = { workspace = true } regex = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 8b64ccfb52cb..13359a702d9a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,6 +17,9 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_functions_aggregate::array_agg::ArrayAgg; +use datafusion_functions_aggregate::nth_value::nth_value_udaf; use sqlparser::ast::{CastKind, Expr as SQLExpr, Subscript, TrimWhereField, Value}; use datafusion_common::{ @@ -26,8 +29,8 @@ use datafusion_common::{ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, Like, Literal, Operator, TryCast, + lit, Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, + Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -596,18 +599,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( - AggregateFunction::ArrayAgg, - ) + match agg_func.func_def { + AggregateFunctionDefinition::UDF(ref udf) => { + udf.inner().as_any().downcast_ref::().is_some() + } + _ => false, + } } match expr { Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { let mut new_args = agg_func.args.clone(); new_args.push(index.clone()); Some(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new( - AggregateFunction::NthValue, + datafusion_expr::expr::AggregateFunction::new_udf( + nth_value_udaf(), new_args, agg_func.distinct, agg_func.filter.clone(), diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index d51c69496d46..d0a387808016 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -124,7 +124,7 @@ from aggregate_test_100 order by c9 # WindowFunction with BuiltInWindowFunction wrong signature -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'NTH_VALUE\(Int32, Int64, Int64\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tNTH_VALUE\(Any, Any\) +statement error No function matches the given name and argument types select c9, nth_value(c5, 2, 3) over (order by c9) as nv1 diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 8ccf3ae85345..c9fb1078910f 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4974,7 +4974,7 @@ logical_plan 02)--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] 03)----TableScan: multiple_ordered_table projection=[a, b, c] physical_plan -01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]], ordering_mode=Sorted +01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query II?