diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index fd471e750194..b18ac57cd1c3 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -980,6 +980,16 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "concat-idents" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f76990911f2267d837d9d0ad060aa63aaad170af40904b29461734c339030d4d" +dependencies = [ + "quote", + "syn 2.0.61", +] + [[package]] name = "const-random" version = "0.1.18" @@ -1286,6 +1296,7 @@ name = "datafusion-functions-aggregate" version = "38.0.0" dependencies = [ "arrow", + "concat-idents", "datafusion-common", "datafusion-execution", "datafusion-expr", diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index f97647565364..ef09273dfa50 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -39,6 +39,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +concat-idents = "1.1.5" datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } diff --git a/datafusion/functions-aggregate/src/expr_builder.rs b/datafusion/functions-aggregate/src/expr_builder.rs new file mode 100644 index 000000000000..594203ae7497 --- /dev/null +++ b/datafusion/functions-aggregate/src/expr_builder.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion_expr::{expr::AggregateFunction, Expr}; +use sqlparser::ast::NullTreatment; + +pub struct ExprBuilder { + udf: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + /// Whether this is a DISTINCT aggregation or not + distinct: bool, + /// Optional filter + filter: Option>, + /// Optional ordering + order_by: Option>, + null_treatment: Option, +} + +impl ExprBuilder { + pub fn new(udf: Arc, args: Vec) -> Self { + Self { + udf, + args, + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + } + } + + pub fn new_distinct(udf: Arc, args: Vec) -> Self { + Self { + udf, + args, + distinct: true, + filter: None, + order_by: None, + null_treatment: None, + } + } +} + +impl ExprBuilder { + pub fn build(self) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + self.udf, + self.args, + self.distinct, + self.filter, + self.order_by, + self.null_treatment, + )) + } + + pub fn order_by(mut self, order_by: Vec) -> Self { + self.order_by = Some(order_by); + self + } + + pub fn filter(mut self, filter: Box) -> Self { + self.filter = Some(filter); + self + } + + pub fn null_treatment(mut self, null_treatment: NullTreatment) -> Self { + self.null_treatment = Some(null_treatment); + self + } +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 5d3d48344014..2bf4daf4f0f9 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -47,6 +47,8 @@ use std::sync::Arc; make_udaf_expr_and_func!( FirstValue, first_value, + expression, + order_by, "Returns the first value in a group of values.", first_value_udaf ); diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index e76a43e39899..0849f0713026 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod covariance; +pub mod expr_builder; pub mod first_last; use datafusion_common::Result; @@ -66,8 +67,13 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::covariance::covar_pop; pub use super::covariance::covar_samp; pub use super::first_last::first_value; + + pub use super::covariance::covar_pop_builder; + pub use super::covariance::covar_samp_builder; + pub use super::first_last::first_value_builder; } /// Registers all enabled packages with a [`FunctionRegistry`] diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 6c3348d6c1d6..1f4d2ae8425d 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -48,24 +48,40 @@ macro_rules! make_udaf_expr_and_func { None, )) } + + create_builder!( + $EXPR_FN, + $($arg)*, + $DOC, + $AGGREGATE_UDF_FN + ); + create_func!($UDAF, $AGGREGATE_UDF_FN); }; - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $order_by: ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( $($arg: datafusion_expr::Expr,)* - distinct: bool, + order_by: Option>, ) -> datafusion_expr::Expr { datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), vec![$($arg),*], - distinct, + false, None, + order_by, None, - None )) } + + create_builder!( + $EXPR_FN, + $($arg)*, + $DOC, + $AGGREGATE_UDF_FN + ); + create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,24 +89,57 @@ macro_rules! make_udaf_expr_and_func { #[doc = $DOC] pub fn $EXPR_FN( args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option ) -> datafusion_expr::Expr { datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( $AGGREGATE_UDF_FN(), args, - distinct, - filter, - order_by, - null_treatment, + false, + None, + None, + None, )) } + + create_builder!( + $EXPR_FN, + $DOC, + $AGGREGATE_UDF_FN + ); + create_func!($UDAF, $AGGREGATE_UDF_FN); }; } +macro_rules! create_builder { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + concat_idents::concat_idents!(builder_fn_name = $EXPR_FN, _builder { + #[doc = $DOC] + pub fn builder_fn_name( + $($arg: datafusion_expr::Expr,)* + ) -> crate::expr_builder::ExprBuilder { + crate::expr_builder::ExprBuilder::new( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + ) + } + }); + }; + + ($EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + concat_idents::concat_idents!(builder_fn_name = $EXPR_FN, _builder { + #[doc = $DOC] + pub fn builder_fn_name( + args: Vec, + ) -> crate::expr_builder::ExprBuilder { + crate::expr_builder::ExprBuilder::new( + $AGGREGATE_UDF_FN(), + args, + ) + } + }); + }; +} + macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { paste::paste! { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b5b0b4c2247a..e43193f0c43c 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -31,8 +31,10 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; -use datafusion::functions_aggregate::expr_fn::first_value; +use datafusion::functions_aggregate::expr_fn::{ + covar_pop, covar_pop_builder, covar_samp, covar_samp_builder, first_value, + first_value_builder, +}; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -621,9 +623,12 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(vec![lit(1)], false, None, None, None), + first_value(lit(1), Some(vec![lit(2)])), + first_value_builder(lit(1)).order_by(vec![lit(3)]).build(), covar_samp(lit(1.5), lit(2.2)), + covar_samp_builder(lit(1.5), lit(2.3)).build(), covar_pop(lit(1.5), lit(2.2)), + covar_pop_builder(lit(1.5), lit(2.3)).build(), ]; // ensure expressions created with the expr api can be round tripped