diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index e3d2e6555d5c..5899cc927703 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -33,8 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Count - Count, /// Minimum Min, /// Maximum @@ -89,7 +87,6 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Count => "COUNT", Min => "MIN", Max => "MAX", Avg => "AVG", @@ -135,7 +132,6 @@ impl FromStr for AggregateFunction { "bit_xor" => AggregateFunction::BitXor, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, - "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, "min" => AggregateFunction::Min, @@ -190,7 +186,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Count => Ok(DataType::Int64), AggregateFunction::Max | AggregateFunction::Min => { // For min and max agg function, the returned type is same as input type. // The coerced_data_types is same with input_types. @@ -249,7 +244,6 @@ impl AggregateFunction { pub fn signature(&self) -> Signature { // note: the physical expression must accept the type returned by this function or the execution panics. match self { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), AggregateFunction::Grouping | AggregateFunction::ArrayAgg => { Signature::any(1, Volatility::Immutable) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 57f5414c13bd..9ba866a4c919 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2135,18 +2135,6 @@ mod test { use super::*; - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); @@ -2250,7 +2238,6 @@ mod test { "nth_value", "min", "max", - "count", "avg", ]; for name in names { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ab7deaff9885..2c76407cdfe2 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -96,7 +96,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::Count => Ok(input_types.to_vec()), AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type @@ -525,7 +524,6 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types let funs = vec![ - AggregateFunction::Count, AggregateFunction::ArrayAgg, AggregateFunction::Min, AggregateFunction::Max, diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index a87ec632ba9e..de2af520053a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,9 +25,7 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{ - aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, -}; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - match aggregate_function { + matches!(aggregate_function, AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, - AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ), - args, - .. - } if args.len() == 1 && is_wildcard(&args[0]) => true, - _ => false, - } + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; - match window_function.fun { - WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ) if args.len() == 1 && is_wildcard(&args[0]) => true, + matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => - { - true - } - _ => false, - } + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -123,9 +103,10 @@ mod tests { use datafusion_expr::expr::Sort; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; use datafusion_functions_aggregate::expr_fn::{count, sum}; @@ -240,7 +221,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e14ee763a3c0..e949e1921b97 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + AggregateFunctionDefinition::BuiltIn(_fun) => { + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "COUNT" { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index e147aa8d6be7..11540d3e162e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -818,6 +818,7 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::AggregateExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, @@ -830,6 +831,7 @@ mod tests { WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::count; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -1888,16 +1890,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index de885f5a33a8..d3d22eb53f39 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -362,9 +362,11 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; + use datafusion_expr::AggregateExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; use datafusion_functions_aggregate::sum::sum_udaf; @@ -679,14 +681,11 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -725,19 +724,16 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -748,19 +744,17 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index b3501cca9efa..618e465b6de9 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; @@ -323,7 +324,7 @@ fn test_sql(sql: &str) -> Result { let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); let statement = &ast[0]; - let context_provider = MyContextProvider::default().with_udaf(sum_udaf()); + let context_provider = MyContextProvider::default().with_udaf(sum_udaf()).with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -345,7 +346,8 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index aee7bca3b88f..75f2e12320bf 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; @@ -61,9 +61,6 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, _) => { - return internal_err!("Builtin Count will be removed"); - } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -642,20 +639,6 @@ mod tests { Ok(()) } - #[test] - fn test_count_return_type() -> Result<()> { - let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; - assert_eq!(DataType::Int64, observed); - - let observed = - AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b1897aa58e7d..aa8d0e55b68f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2bb3ec793d7f..31cb0d1da9d5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -476,7 +476,7 @@ enum AggregateFunction { MAX = 1; // SUM = 2; AVG = 3; - COUNT = 4; + // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; // VARIANCE = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 59b7861a6ef1..503f83af65f2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::Count => "COUNT", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -571,7 +570,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "COUNT", "ARRAY_AGG", "CORRELATION", "APPROX_PERCENTILE_CONT", @@ -636,7 +634,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0861c287fcfa..2c0ea62466b4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1930,7 +1930,7 @@ pub enum AggregateFunction { Max = 1, /// SUM = 2; Avg = 3, - Count = 4, + /// COUNT = 4; /// APPROX_DISTINCT = 5; ArrayAgg = 6, /// VARIANCE = 7; @@ -1972,7 +1972,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", - AggregateFunction::Count => "COUNT", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -2004,7 +2003,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), - "COUNT" => Some(Self::Count), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2ad40d883fe6..54a59485c836 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BitXor => Self::BitXor, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, - protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6a275ed7a1b8..80ce05d151ee 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BitXor => Self::BitXor, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, - AggregateFunction::Count => Self::Count, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::RegrSlope => Self::RegrSlope, @@ -406,7 +405,6 @@ pub fn serialize_expr( AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2ed59ee39f75..d0f1c4aade5e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,7 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_functions_aggregate::count::count_udaf; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -53,10 +54,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1782,28 +1783,18 @@ fn roundtrip_similar_to() { #[test] fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - None, - )); + let test_expr = count(col("bananas")); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } #[test] fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - None, - )); + let test_expr = count_udaf() + .call(vec![col("bananas")]) + .distinct() + .build() + .unwrap(); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 893db018c8af..a4ef093be349 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -66,7 +66,8 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index dc25a6c33ece..12c48054f1a7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -960,13 +960,14 @@ mod tests { use arrow_schema::DataType::Int8; use datafusion_common::TableReference; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - case, col, cube, exists, - expr::{AggregateFunction, AggregateFunctionDefinition}, - grouping_set, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, - try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, WindowFrame, WindowFunctionDefinition, + case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col, + placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, + WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; use crate::unparser::dialect::CustomDialect; @@ -1127,29 +1128,19 @@ mod tests { ), (sum(col("a")), r#"sum(a)"#), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: true, - filter: None, - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .distinct() + .build() + .unwrap(), "COUNT(DISTINCT *)", ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: false, - filter: Some(Box::new(lit(true))), - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .filter(lit(true)) + .build() + .unwrap(), "COUNT(*) FILTER (WHERE true)", ), ( @@ -1167,9 +1158,7 @@ mod tests { ), ( Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::AggregateFunction( - datafusion_expr::AggregateFunction::Count, - ), + fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], order_by: vec![Expr::Sort(Sort::new( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 72018371a5f1..5d726d4700cf 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -19,7 +19,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::sum_udaf; +use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -153,7 +153,7 @@ fn roundtrip_statement() -> Result<()> { .try_with_sql(query)? .parse_statement()?; - let context = MockContextProvider::default().with_udaf(sum_udaf()); + let context = MockContextProvider::default().with_udaf(sum_udaf()).with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index d91c09ae1287..893678d6b374 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -46,7 +46,8 @@ impl MockContextProvider { } pub(crate) fn with_udaf(mut self, udaf: Arc) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 7b9d39a2b51e..4b4279ec4157 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,7 +37,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions_aggregate::approx_median::approx_median_udaf; +use datafusion_functions_aggregate::{approx_median::approx_median_udaf, count::count_udaf}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -2702,7 +2702,8 @@ fn logical_plan_with_dialect_and_options( )) .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) .with_udaf(sum_udaf()) - .with_udaf(approx_median_udaf()); + .with_udaf(approx_median_udaf()) + .with_udaf(count_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect);