From c6ba32b18d67b9aaae4512160fd90b6fb6fda7ee Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 13 May 2024 21:56:42 +0800 Subject: [PATCH] fix with args Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 3 ++ datafusion-examples/examples/advanced_udaf.rs | 18 +++---- .../examples/simplify_udaf_expression.rs | 16 +++--- .../user_defined/user_defined_aggregates.rs | 10 ++-- datafusion/expr/src/expr_fn.rs | 10 +--- datafusion/expr/src/function.rs | 17 +++++- datafusion/expr/src/udaf.rs | 53 +++++++------------ datafusion/functions-aggregate/src/count.rs | 24 ++++----- .../functions-aggregate/src/covariance.rs | 26 +++------ .../functions-aggregate/src/first_last.rs | 17 ++---- .../simplify_expressions/expr_simplifier.rs | 9 +++- datafusion/physical-expr-common/Cargo.toml | 2 +- .../physical-expr-common/src/aggregate/mod.rs | 25 +++++---- .../tests/cases/roundtrip_logical_plan.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 1 + 15 files changed, 112 insertions(+), 121 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index fd471e750194..8cb37b0d11b8 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1206,6 +1206,7 @@ dependencies = [ "arrow-schema", "chrono", "half", + "hashbrown 0.14.5", "instant", "libc", "num_cpus", @@ -1363,9 +1364,11 @@ dependencies = [ name = "datafusion-physical-expr-common" version = "38.0.0" dependencies = [ + "ahash", "arrow", "datafusion-common", "datafusion-expr", + "hashbrown 0.14.5", ] [[package]] diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..4b8ca57e4b40 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,8 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, - GroupsAccumulator, Signature, + function::{AccumulatorArgs, GroupsAccumulatorSupportedArgs, StateFieldsArgs}, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -92,21 +92,19 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields( - &self, - _name: &str, - value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", value_type, true), + Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` /// which is used for cases when there are grouping columns in the query - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported( + &self, + _args: GroupsAccumulatorSupportedArgs, + ) -> bool { true } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 92deb20272e4..8dcd6fb6424d 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -17,7 +17,9 @@ use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::function::{ + AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs, StateFieldsArgs, +}; use datafusion_expr::simplify::SimplifyInfo; use std::{any::Any, sync::Arc}; @@ -70,16 +72,14 @@ impl AggregateUDFImpl for BetterAvgUdaf { unimplemented!("should not be invoked") } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported( + &self, + _args: GroupsAccumulatorSupportedArgs, + ) -> bool { true } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..80526d812b92 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -45,8 +45,9 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - SimpleAggregateUDF, + create_udaf, + function::{AccumulatorArgs, GroupsAccumulatorSupportedArgs}, + AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -725,7 +726,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported( + &self, + _args: GroupsAccumulatorSupportedArgs, + ) -> bool { true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b43f9f4bffbd..d2bd3b7d1f2e 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -23,6 +23,7 @@ use crate::expr::{ }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, + StateFieldsArgs, }; use crate::{ aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, @@ -692,14 +693,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - _is_distinct: bool, - _input_type: DataType, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index c1be4e6cab07..466ffceeca8a 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,7 +19,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -91,6 +91,21 @@ impl<'a> AccumulatorArgs<'a> { } } +/// [`GroupsAccumulatorSupportedArgs`] contains information to determine if an +/// aggregate function supports the groups accumulator. +pub struct GroupsAccumulatorSupportedArgs { + pub args_num: usize, + pub is_distinct: bool, +} + +pub struct StateFieldsArgs<'a> { + pub name: &'a str, + pub input_type: &'a DataType, + pub return_type: &'a DataType, + pub ordering_fields: &'a [Field], + pub is_distinct: bool, +} + /// Factory that returns an accumulator for the given aggregate function. pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 4b9be9e69eab..c9f7125f8c7e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,10 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; +use crate::function::{ + AccumulatorArgs, AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs, + StateFieldsArgs, +}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -177,31 +180,16 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - is_distinct: bool, - input_type: DataType, - ) -> Result> { - self.inner.state_fields( - name, - value_type, - ordering_fields, - is_distinct, - input_type, - ) + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. pub fn groups_accumulator_supported( &self, - args_num: usize, - is_distinct: bool, + args: GroupsAccumulatorSupportedArgs, ) -> bool { - self.inner - .groups_accumulator_supported(args_num, is_distinct) + self.inner.groups_accumulator_supported(args) } /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. @@ -338,21 +326,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - _is_distinct: bool, - _input_type: DataType, - ) -> Result> { - let value_fields = vec![Field::new( - format_state_name(name, "value"), - value_type, + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Field::new( + format_state_name(args.name, "value"), + args.return_type.clone(), true, )]; - Ok(value_fields.into_iter().chain(ordering_fields).collect()) + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) } /// If the aggregate expression has a specialized @@ -365,7 +349,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// `Self::accumulator` for certain queries, such as when this aggregate is /// used as a window function or when there no GROUP BY columns in the /// query. - fn groups_accumulator_supported(&self, _args_num: usize, _is_distinct: bool) -> bool { + fn groups_accumulator_supported( + &self, + _args: GroupsAccumulatorSupportedArgs, + ) -> bool { false } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 7907798820f9..53369a7775bc 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -40,6 +40,7 @@ use datafusion_common::{ downcast_value, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::function::{GroupsAccumulatorSupportedArgs, StateFieldsArgs}; use datafusion_expr::Expr; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, @@ -120,23 +121,16 @@ impl AggregateUDFImpl for Count { Ok(DataType::Int64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - is_distinct: bool, - input_type: DataType, - ) -> Result> { - if is_distinct { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { Ok(vec![Field::new_list( - format_state_name(name, "count distinct"), - Field::new("item", input_type, true), + format_state_name(args.name, "count distinct"), + Field::new("item", args.input_type.clone(), true), false, )]) } else { Ok(vec![Field::new( - format_state_name(name, "count"), + format_state_name(args.name, "count"), DataType::Int64, true, )]) @@ -258,13 +252,13 @@ impl AggregateUDFImpl for Count { &self.aliases } - fn groups_accumulator_supported(&self, args_num: usize, is_distinct: bool) -> bool { + fn groups_accumulator_supported(&self, args: GroupsAccumulatorSupportedArgs) -> bool { // groups accumulator only supports `COUNT(c1)`, not // `COUNT(c1, c2)`, etc - if is_distinct { + if args.is_distinct { return false; } - args_num == 1 + args.args_num == 1 } fn create_groups_accumulator(&self) -> Result> { diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 518a4f7bebc0..6f03b256fd9f 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -30,8 +30,10 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - function::AccumulatorArgs, type_coercion::aggregates::NUMERICS, - utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Signature, Volatility, }; use datafusion_physical_expr_common::aggregate::stats::StatsType; @@ -101,14 +103,8 @@ impl AggregateUDFImpl for CovarianceSample { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - _is_distinct: bool, - _input_type: DataType, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), @@ -178,14 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation { Ok(DataType::Float64) } - fn state_fields( - &self, - name: &str, - _value_type: DataType, - _ordering_fields: Vec, - _is_distinct: bool, - _input_type: DataType, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; Ok(vec![ Field::new(format_state_name(name, "count"), DataType::UInt64, true), Field::new(format_state_name(name, "mean1"), DataType::Float64, true), diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 833b608202b4..5d3d48344014 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ @@ -147,20 +147,13 @@ impl AggregateUDFImpl for FirstValue { .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - _is_distinct: bool, - _input_type: DataType, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, + format_state_name(args.name, "first_value"), + args.return_type.clone(), true, )]; - fields.extend(ordering_fields); + fields.extend(args.ordering_fields.to_vec()); fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 55052542a8bf..f71516b8d71b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + function::{AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs}, + interval_arithmetic::Interval, + *, }; use std::{ collections::HashMap, @@ -3783,7 +3785,10 @@ mod tests { unimplemented!("not needed for tests") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported( + &self, + _args: GroupsAccumulatorSupportedArgs, + ) -> bool { unimplemented!("not needed for testing") } diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 987cf1a7a7c4..41e38ef28cdb 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -42,4 +42,4 @@ ahash = { version = "0.8", default-features = false, features = [ arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } \ No newline at end of file +hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 3cd761c58d47..69fec4c2057c 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -22,6 +22,7 @@ pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::function::{GroupsAccumulatorSupportedArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::ReversedUDAF; use datafusion_expr::{ @@ -195,13 +196,15 @@ impl AggregateExpr for AggregateFunctionExpr { } fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - self.is_distinct, - self.input_type.clone(), - ) + let args = StateFieldsArgs { + name: self.name(), + input_type: &self.input_type, + return_type: &self.data_type, + ordering_fields: &self.ordering_fields, + is_distinct: self.is_distinct, + }; + + self.fun.state_fields(args) } fn field(&self) -> Result { @@ -281,8 +284,12 @@ impl AggregateExpr for AggregateFunctionExpr { } fn groups_accumulator_supported(&self) -> bool { - self.fun - .groups_accumulator_supported(self.args.len(), self.is_distinct) + let args = GroupsAccumulatorSupportedArgs { + args_num: self.args.len(), + is_distinct: self.is_distinct, + }; + + self.fun.groups_accumulator_supported(args) } fn create_groups_accumulator(&self) -> Result> { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cd31c58a4b08..e78d5985c080 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -25,7 +25,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; -use datafusion::functions_aggregate::expr_fn::{self, first_value}; +use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c2018352c7cf..30a28081edff 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -426,6 +426,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &schema, "example_agg", false, + false, )?]; roundtrip_test_with_context(