diff --git a/Cargo.toml b/Cargo.toml index ca34ea9c2a24..3af3db6f6626 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ members = [ "datafusion/core", "datafusion/expr", "datafusion/execution", + "datafusion/functions-aggregate", "datafusion/functions", "datafusion/functions-array", "datafusion/optimizer", @@ -78,6 +79,7 @@ datafusion-common-runtime = { path = "datafusion/common-runtime", version = "37. datafusion-execution = { path = "datafusion/execution", version = "37.0.0" } datafusion-expr = { path = "datafusion/expr", version = "37.0.0" } datafusion-functions = { path = "datafusion/functions", version = "37.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "37.0.0" } datafusion-functions-array = { path = "datafusion/functions-array", version = "37.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "37.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "37.0.0", default-features = false } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index d744a891c6a6..a0a7b20ac40e 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1135,6 +1135,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", "datafusion-functions-array", "datafusion-optimizer", "datafusion-physical-expr", @@ -1279,6 +1280,19 @@ dependencies = [ "uuid", ] +[[package]] +name = "datafusion-functions-aggregate" +version = "37.0.0" +dependencies = [ + "arrow", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "log", + "paste", +] + [[package]] name = "datafusion-functions-array" version = "37.0.0" @@ -1331,6 +1345,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr-common", "half", "hashbrown 0.14.3", @@ -1370,7 
+1385,9 @@ dependencies = [ "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate", "datafusion-physical-expr", + "datafusion-physical-expr-common", "futures", "half", "hashbrown 0.14.3", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 610784f91dec..4f18cb5cb74d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -98,6 +98,7 @@ datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-functions-array = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 1a582be3013d..8fc60770105b 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -44,6 +44,7 @@ use crate::{ datasource::{provider_as_source, MemTable, TableProvider, ViewTable}, error::{DataFusionError, Result}, execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, + logical_expr::AggregateUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, @@ -53,10 +54,11 @@ use crate::{ optimizer::analyzer::{Analyzer, AnalyzerRule}, optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, - physical_plan::{udaf::AggregateUDF, udf::ScalarUDF, ExecutionPlan}, + physical_plan::{udf::ScalarUDF, ExecutionPlan}, physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, variable::{VarProvider, VarType}, }; +use crate::{functions, functions_aggregate, functions_array}; use 
arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -69,14 +71,11 @@ use datafusion_common::{ SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; -use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::{create_first_value, Signature, Volatility}; use datafusion_expr::{ logical_plan::{DdlStatement, Statement}, var_provider::is_system_variables, Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; -use datafusion_physical_expr::create_first_value_accumulator; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, @@ -85,7 +84,6 @@ use datafusion_sql::{ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use log::debug; use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; use url::Url; @@ -1452,29 +1450,16 @@ impl SessionState { }; // register built in functions - datafusion_functions::register_all(&mut new_self) + functions::register_all(&mut new_self) .expect("can not register built in functions"); // register crate of array expressions (if enabled) #[cfg(feature = "array_expressions")] - datafusion_functions_array::register_all(&mut new_self) + functions_array::register_all(&mut new_self) .expect("can not register array expressions"); - let first_value = create_first_value( - "FIRST_VALUE", - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), - Arc::new(create_first_value_accumulator), - ); - - match new_self.register_udaf(Arc::new(first_value)) { - Ok(Some(existing_udaf)) => { - debug!("Overwrite existing UDAF: {}", existing_udaf.name()); - } - Ok(None) => {} - Err(err) => { - panic!("Failed to register UDAF: {}", err); - } - } + functions_aggregate::register_all(&mut new_self) + .expect("can not register aggregate functions"); new_self } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 
f6e2171d6b5f..93eafb8d776e 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -538,6 +538,11 @@ pub mod functions_array { pub use datafusion_functions_array::*; } +/// re-export of [`datafusion_functions_aggregate`] crate +pub mod functions_aggregate { + pub use datafusion_functions_aggregate::*; +} + #[cfg(test)] pub mod test; pub mod test_util; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a1235a093d76..f68685a87f13 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -24,7 +24,6 @@ use crate::expr::{ use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, }; -use crate::udaf::format_state_name; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, @@ -708,17 +707,6 @@ pub fn create_udaf( )) } -/// Creates a new UDAF with a specific signature, state type and return type. -/// The signature and state type must match the `Accumulator's implementation`. -/// TOOD: We plan to move aggregate function to its own crate. This function will be deprecated then. -pub fn create_first_value( - name: &str, - signature: Signature, - accumulator: AccumulatorFactoryFunction, -) -> AggregateUDF { - AggregateUDF::from(FirstValue::new(name, signature, accumulator)) -} - /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. 
pub struct SimpleAggregateUDF { @@ -813,78 +801,6 @@ impl AggregateUDFImpl for SimpleAggregateUDF { } } -pub struct FirstValue { - name: String, - signature: Signature, - accumulator: AccumulatorFactoryFunction, -} - -impl Debug for FirstValue { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("FirstValue") - .field("name", &self.name) - .field("signature", &self.signature) - .field("accumulator", &"") - .finish() - } -} - -impl FirstValue { - pub fn new( - name: impl Into, - signature: Signature, - accumulator: AccumulatorFactoryFunction, - ) -> Self { - let name = name.into(); - Self { - name, - signature, - accumulator, - } - } -} - -impl AggregateUDFImpl for FirstValue { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn accumulator( - &self, - acc_args: AccumulatorArgs, - ) -> Result> { - (self.accumulator)(acc_args) - } - - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, - true, - )]; - fields.extend(ordering_fields); - fields.push(Field::new("is_set", DataType::Boolean, true)); - Ok(fields) - } -} - /// Creates a new UDWF with a specific signature, state type and return type. /// /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. 
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3cf1845aacd6..856f0dc44246 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -19,6 +19,7 @@ use crate::function::AccumulatorArgs; use crate::groups_accumulator::GroupsAccumulator; +use crate::utils::format_state_name; use crate::{Accumulator, Expr}; use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; use arrow::datatypes::{DataType, Field}; @@ -447,9 +448,3 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// returns the name of the state -/// TODO: Remove duplicated function in physical-expr -pub(crate) fn format_state_name(name: &str, state_name: &str) -> String { - format!("{name}[{state_name}]") -} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 72d01da20448..a93282574e8a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1240,6 +1240,11 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { } } +/// Build state name. State is the intermediate state of the aggregate function. +pub fn format_state_name(name: &str, state_name: &str) -> String { + format!("{name}[{state_name}]") +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml new file mode 100644 index 000000000000..d42932d8abdd --- /dev/null +++ b/datafusion/functions-aggregate/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-aggregate" +description = "Aggregate function packages for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lib] +name = "datafusion_functions_aggregate" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +log = { workspace = true } +paste = "1.0.14" diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs similarity index 85% rename from datafusion/physical-expr/src/aggregate/first_last.rs rename to datafusion/functions-aggregate/src/first_last.rs index 26bd219f65f0..d5367ad34163 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -17,209 +17,149 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. 
-use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, ordering_fields}; -use crate::expressions::{self, format_state_name}; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn}; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use arrow::compute::{self, lexsort_to_indices, SortColumn, SortOptions}; use arrow::datatypes::{DataType, Field}; -use arrow_schema::SortOptions; use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::{Accumulator, Expr}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility}; +use datafusion_physical_expr_common::aggregate::utils::{ + down_cast_any_ref, get_sort_options, ordering_fields, +}; +use datafusion_physical_expr_common::aggregate::AggregateExpr; +use datafusion_physical_expr_common::expressions; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::utils::reverse_order_bys; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +make_udaf_function!( + FirstValue, + first_value, + value, + "Returns the first value in a group of values.", + first_value_udaf +); -/// FIRST_VALUE aggregate expression -#[derive(Debug, Clone)] pub struct FirstValue { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, - requirement_satisfied: bool, - ignore_nulls: bool, - state_fields: Vec, + 
signature: Signature, + aliases: Vec, +} + +impl Debug for FirstValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("FirstValue") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() + } +} + +impl Default for FirstValue { + fn default() -> Self { + Self::new() + } } impl FirstValue { - /// Creates a new FIRST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - state_fields: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); + pub fn new() -> Self { Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - state_fields, + aliases: vec![String::from("FIRST_VALUE")], + signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), } } +} - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; +impl AggregateUDFImpl for FirstValue { + fn as_any(&self) -> &dyn Any { self } - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name + fn name(&self) -> &str { + "FIRST_VALUE" } - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type + fn signature(&self) -> &Signature { + &self.signature } - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - /// Returns the expression associated with the aggregate function. 
- pub fn expr(&self) -> &Arc { - &self.expr - } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let mut all_sort_orders = vec![]; - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req - } + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in acc_args.sort_exprs { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::column::col(name, acc_args.schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } + let ordering_req = all_sort_orders; - pub fn convert_to_last(self) -> LastValue { - let name = if self.name.starts_with("FIRST") { - format!("LAST{}", &self.name[5..]) - } else { - format!("LAST_VALUE({})", self.expr) - }; - let FirstValue { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. 
- } = self; - LastValue::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - ) - } -} - -impl AggregateExpr for FirstValue { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } + let requirement_satisfied = ordering_req.is_empty(); - fn create_accumulator(&self) -> Result> { FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, + acc_args.data_type, + &ordering_dtypes, + ordering_req, + acc_args.ignore_nulls, ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self) -> Result> { - if !self.state_fields.is_empty() { - return Ok(self.state_fields.clone()); - } - + fn state_fields( + &self, + name: &str, + value_type: DataType, + ordering_fields: Vec, + ) -> Result> { let mut fields = vec![Field::new( - format_state_name(&self.name, "first_value"), - self.input_data_type.clone(), + format_state_name(name, "first_value"), + value_type, true, )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); + fields.extend(ordering_fields); + fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - 
Some(Arc::new(self.clone().convert_to_last())) - } - - fn create_sliding_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } -} - -impl PartialEq for FirstValue { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn aliases(&self) -> &[String] { + &self.aliases } } #[derive(Debug)] -struct FirstValueAccumulator { +pub struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. @@ -258,6 +198,11 @@ impl FirstValueAccumulator { }) } + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + // Updates state with the values in the given row. 
fn update_with_new_row(&mut self, row: &[ScalarValue]) { self.first = row[0].clone(); @@ -307,11 +252,6 @@ impl FirstValueAccumulator { Ok((!indices.is_empty()).then_some(indices.value(0) as _)) } } - - fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self - } } impl Accumulator for FirstValueAccumulator { @@ -393,53 +333,190 @@ impl Accumulator for FirstValueAccumulator { } } -pub fn create_first_value_accumulator( - acc_args: AccumulatorArgs, -) -> Result> { - let mut all_sort_orders = vec![]; - - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in acc_args.sort_exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::col(name, acc_args.schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } +/// TO BE DEPRECATED: Builtin FIRST_VALUE physical aggregate expression will be replaced by udf in the future +#[derive(Debug, Clone)] +pub struct FirstValuePhysicalExpr { + name: String, + input_data_type: DataType, + order_by_data_types: Vec, + expr: Arc, + ordering_req: LexOrdering, + requirement_satisfied: bool, + ignore_nulls: bool, + state_fields: Vec, +} + +impl FirstValuePhysicalExpr { + /// Creates a new FIRST_VALUE aggregation function. 
+ pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + ordering_req: LexOrdering, + order_by_data_types: Vec, + state_fields: Vec, + ) -> Self { + let requirement_satisfied = ordering_req.is_empty(); + Self { + name: name.into(), + input_data_type, + order_by_data_types, + expr, + ordering_req, + requirement_satisfied, + ignore_nulls: false, + state_fields, } } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self } - let ordering_req = all_sort_orders; + /// Returns the name of the aggregate expression. + pub fn name(&self) -> &str { + &self.name + } - let ordering_dtypes = ordering_req - .iter() - .map(|e| e.expr.data_type(acc_args.schema)) - .collect::>>()?; - - let requirement_satisfied = ordering_req.is_empty(); - - FirstValueAccumulator::try_new( - acc_args.data_type, - &ordering_dtypes, - ordering_req, - acc_args.ignore_nulls, - ) - .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) + /// Returns the input data type of the aggregate expression. + pub fn input_data_type(&self) -> &DataType { + &self.input_data_type + } + + /// Returns the data types of the order-by columns. + pub fn order_by_data_types(&self) -> &Vec { + &self.order_by_data_types + } + + /// Returns the expression associated with the aggregate function. + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Returns the lexical ordering requirements of the aggregate expression. 
+ pub fn ordering_req(&self) -> &LexOrdering { + &self.ordering_req + } + + pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } + + pub fn convert_to_last(self) -> LastValuePhysicalExpr { + let name = if self.name.starts_with("FIRST") { + format!("LAST{}", &self.name[5..]) + } else { + format!("LAST_VALUE({})", self.expr) + }; + let FirstValuePhysicalExpr { + expr, + input_data_type, + ordering_req, + order_by_data_types, + .. + } = self; + LastValuePhysicalExpr::new( + expr, + name, + input_data_type, + reverse_order_bys(&ordering_req), + order_by_data_types, + ) + } } -/// LAST_VALUE aggregate expression +impl AggregateExpr for FirstValuePhysicalExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } + + fn state_fields(&self) -> Result> { + if !self.state_fields.is_empty() { + return Ok(self.state_fields.clone()); + } + + let mut fields = vec![Field::new( + format_state_name(&self.name, "first_value"), + self.input_data_type.clone(), + true, + )]; + fields.extend(ordering_fields( + &self.ordering_req, + &self.order_by_data_types, + )); + fields.push(Field::new( + format_state_name(&self.name, "is_set"), + DataType::Boolean, + true, + )); + Ok(fields) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> Option> { + 
Some(Arc::new(self.clone().convert_to_last())) + } + + fn create_sliding_accumulator(&self) -> Result> { + FirstValueAccumulator::try_new( + &self.input_data_type, + &self.order_by_data_types, + self.ordering_req.clone(), + self.ignore_nulls, + ) + .map(|acc| { + Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ + }) + } +} + +impl PartialEq for FirstValuePhysicalExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.order_by_data_types == x.order_by_data_types + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +/// TO BE DEPRECATED: Builtin LAST_VALUE physical aggregate expression will be replaced by udf in the future #[derive(Debug, Clone)] -pub struct LastValue { +pub struct LastValuePhysicalExpr { name: String, input_data_type: DataType, order_by_data_types: Vec, @@ -449,7 +526,7 @@ pub struct LastValue { ignore_nulls: bool, } -impl LastValue { +impl LastValuePhysicalExpr { /// Creates a new LAST_VALUE aggregation function. pub fn new( expr: Arc, @@ -505,20 +582,20 @@ impl LastValue { self } - pub fn convert_to_first(self) -> FirstValue { + pub fn convert_to_first(self) -> FirstValuePhysicalExpr { let name = if self.name.starts_with("LAST") { format!("FIRST{}", &self.name[4..]) } else { format!("FIRST_VALUE({})", self.expr) }; - let LastValue { + let LastValuePhysicalExpr { expr, input_data_type, ordering_req, order_by_data_types, .. 
} = self; - FirstValue::new( + FirstValuePhysicalExpr::new( expr, name, input_data_type, @@ -529,7 +606,7 @@ impl LastValue { } } -impl AggregateExpr for LastValue { +impl AggregateExpr for LastValuePhysicalExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -598,7 +675,7 @@ impl AggregateExpr for LastValue { } } -impl PartialEq for LastValue { +impl PartialEq for LastValuePhysicalExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -820,15 +897,9 @@ fn convert_to_sort_cols( #[cfg(test)] mod tests { - use std::sync::Arc; - - use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow::array::Int64Array; - use arrow::compute::concat; - use arrow_array::{ArrayRef, Int64Array}; - use arrow_schema::DataType; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::Accumulator; + use super::*; #[test] fn test_first_last_value_value() -> Result<()> { @@ -888,7 +959,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[ + states.push(arrow::compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); @@ -918,7 +989,7 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(concat(&[ + states.push(arrow::compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs new file mode 100644 index 000000000000..8016b76889f7 --- /dev/null +++ b/datafusion/functions-aggregate/src/lib.rs @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate Function packages for [DataFusion]. +//! +//! This crate contains a collection of various aggregate function packages for DataFusion, +//! implemented using the extension API. Users may wish to control which functions +//! are available to control the binary size of their application as well as +//! use dialect specific implementations of functions (e.g. Spark vs Postgres) +//! +//! Each package is implemented as a separate +//! module, activated by a feature flag. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +//! # Available Packages +//! See the list of [modules](#modules) in this crate for available packages. +//! +//! # Using A Package +//! You can register all functions in all packages using the [`register_all`] function. +//! +//! Each package also exports an `expr_fn` submodule to help create [`Expr`]s that invoke +//! functions using a fluent style. For example: +//! +//![`Expr`]: datafusion_expr::Expr +//! +//! # Implementing A New Package +//! +//! To add a new package to this crate, you should follow the model of existing +//! packages. The high level steps are: +//! +//! 1. Create a new module with the appropriate [AggregateUDF] implementations. +//! +//! 2. Use the macros in [`macros`] to create standard entry points. +//! +//! 3. Add a new feature to `Cargo.toml`, with any optional dependencies +//! +//! 4. 
Use the `make_package!` macro to expose the module when the +//! feature is enabled. + +#[macro_use] +pub mod macros; + +pub mod first_last; + +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use datafusion_expr::AggregateUDF; +use log::debug; +use std::sync::Arc; + +/// Fluent-style API for creating `Expr`s +pub mod expr_fn { + pub use super::first_last::first_value; +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let functions: Vec> = vec![first_last::first_value_udaf()]; + + functions.into_iter().try_for_each(|udf| { + let existing_udaf = registry.register_udaf(udf)?; + if let Some(existing_udaf) = existing_udaf { + debug!("Overwrite existing UDAF: {}", existing_udaf.name()); + } + Ok(()) as Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs new file mode 100644 index 000000000000..d24c60f93270 --- /dev/null +++ b/datafusion/functions-aggregate/src/macros.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! 
make_udaf_function { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + paste::paste! { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN($($arg: Expr),*) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + // TODO: Support arguments for `expr` API + false, + None, + None, + None, + )) + } + + /// Singleton instance of [$UDAF], ensures the UDAF is only created once + /// named STATIC_$(UDAF). For example `STATIC_FirstValue` + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDAF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] + /// + /// [AggregateUDF]: datafusion_expr::AggregateUDF + pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDAF >] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + }) + .clone() + } + } + } +} diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 579f51815d84..33044fd9beee 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -17,16 +17,54 @@ pub mod utils; -use std::any::Any; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, +}; use std::fmt::Debug; -use std::sync::Arc; +use std::{any::Any, sync::Arc}; use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; +use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; +use self::utils::{down_cast_any_ref, ordering_fields}; + +/// Creates a physical expression of the 
UDAF, that includes all necessary type coercion. +/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. +pub fn create_aggregate_expr( + fun: &AggregateUDF, + input_phy_exprs: &[Arc], + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + schema: &Schema, + name: impl Into, + ignore_nulls: bool, +) -> Result> { + let input_exprs_types = input_phy_exprs + .iter() + .map(|arg| arg.data_type(schema)) + .collect::>>()?; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let ordering_fields = ordering_fields(ordering_req, &ordering_types); + + Ok(Arc::new(AggregateFunctionExpr { + fun: fun.clone(), + args: input_phy_exprs.to_vec(), + data_type: fun.return_type(&input_exprs_types)?, + name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), + ignore_nulls, + ordering_fields, + })) +} /// An aggregate expression that: /// * knows its resulting field @@ -100,3 +138,151 @@ pub trait AggregateExpr: Send + Sync + Debug + PartialEq { not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") } } + +/// Physical aggregate expression of a UDAF. 
+#[derive(Debug)] +pub struct AggregateFunctionExpr { + fun: AggregateUDF, + args: Vec>, + /// Output / return type of this aggregate + data_type: DataType, + name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions + ordering_req: LexOrdering, + ignore_nulls: bool, + ordering_fields: Vec, +} + +impl AggregateFunctionExpr { + /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` + pub fn fun(&self) -> &AggregateUDF { + &self.fun + } +} + +impl AggregateExpr for AggregateFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn state_fields(&self) -> Result> { + self.fun.state_fields( + self.name(), + self.data_type.clone(), + self.ordering_fields.clone(), + ) + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + let acc_args = AccumulatorArgs::new( + &self.data_type, + &self.schema, + self.ignore_nulls, + &self.sort_exprs, + ); + + self.fun.accumulator(acc_args) + } + + fn create_sliding_accumulator(&self) -> Result> { + let accumulator = self.create_accumulator()?; + + // Accumulators that have window frame startings different + // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to + // implement retract_batch method in order to run correctly + // currently in DataFusion. + // + // If this `retract_batches` is not present, there is no way + // to calculate result correctly. For example, the query + // + // ```sql + // SELECT + // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a + // FROM + // t + // ``` + // + // 1. First sum value will be the sum of rows between `[0, 1)`, + // + // 2. Second sum value will be the sum of rows between `[0, 2)` + // + // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. 
+ // + // Since the accumulator keeps the running sum: + // + // 1. First sum we add to the state sum value between `[0, 1)` + // + // 2. Second sum we add to the state sum value between `[1, 2)` + // (`[0, 1)` is already in the state sum, hence running sum will + // cover `[0, 2)` range) + // + // 3. Third sum we add to the state sum value between `[2, 3)` + // (`[0, 2)` is already in the state sum). Also we need to + // retract values between `[0, 1)` by this way we can obtain sum + // between [1, 3) which is indeed the apropriate range. + // + // When we use `UNBOUNDED PRECEDING` in the query starting + // index will always be 0 for the desired range, and hence the + // `retract_batch` method will not be called. In this case + // having retract_batch is not a requirement. + // + // This approach is a a bit different than window function + // approach. In window function (when they use a window frame) + // they get all the desired range during evaluation. + if !accumulator.supports_retract_batch() { + return not_impl_err!( + "Aggregate can not be used as a sliding accumulator because \ + `retract_batch` is not implemented: {}", + self.name + ); + } + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } + + fn groups_accumulator_supported(&self) -> bool { + self.fun.groups_accumulator_supported() + } + + fn create_groups_accumulator(&self) -> Result> { + self.fun.create_groups_accumulator() + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } +} + +impl PartialEq for AggregateFunctionExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.fun == x.fun + && self.args.len() == x.args.len() + && self + .args + .iter() + .zip(x.args.iter()) + .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) + }) + .unwrap_or(false) + } +} diff --git 
a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 87d73183d0dd..72fac5370ae0 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -59,6 +59,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index e176084ae6ec..eff008e8f825 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -38,7 +38,6 @@ pub(crate) mod correlation; pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; -pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; pub(crate) mod nth_value; diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index d14a52f5752d..6d97ad3da6de 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -20,9 +20,9 @@ use std::sync::Arc; // For backwards compatibility -pub use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref; -pub use datafusion_physical_expr_common::aggregate::utils::get_sort_options; -pub use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +pub use datafusion_physical_expr_common::aggregate::utils::{ + down_cast_any_ref, get_sort_options, ordering_fields, +}; use arrow::array::{ArrayRef, ArrowNativeTypeOp}; use arrow_array::cast::AsArray; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index f0cc4b175ea5..688d5ce6eabf 100644 --- 
a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; @@ -76,11 +75,15 @@ pub use crate::window::rank::{dense_rank, percent_rank, rank}; pub use crate::window::rank::{Rank, RankType}; pub use crate::window::row_number::RowNumber; pub use crate::PhysicalSortExpr; +pub use datafusion_functions_aggregate::first_last::{ + FirstValuePhysicalExpr as FirstValue, LastValuePhysicalExpr as LastValue, +}; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, cast_with_options, CastExpr}; pub use column::UnKnownColumn; +pub use datafusion_expr::utils::format_state_name; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; @@ -92,11 +95,6 @@ pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; -/// returns the name of the state -pub fn format_state_name(name: &str, state_name: &str) -> String { - format!("{name}[{state_name}]") -} - #[cfg(test)] pub(crate) mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index c88f1b32bbc6..7b81e8f8a5c4 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -61,8 +61,6 @@ pub use scalar_function::ScalarFunctionExpr; pub use datafusion_physical_expr_common::utils::reverse_order_bys; pub use utils::split_conjunction; -pub use 
aggregate::first_last::create_first_value_accumulator; - // For backwards compatibility pub mod sort_properties { pub use datafusion_physical_expr_common::sort_properties::{ diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 1ba32bff746e..6a78bd596a46 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -46,7 +46,9 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } +datafusion-physical-expr-common = { workspace = true } futures = { workspace = true } half = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index f8ad03bf6d97..98c44e23c6c7 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1235,7 +1235,7 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, + lit, ApproxDistinct, Count, LastValue, Median, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{ reverse_order_bys, AggregateExpr, EquivalenceProperties, PhysicalExpr, diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 3decf2e34015..e1c8489655bf 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -66,7 +66,6 @@ pub mod sorts; pub mod stream; pub mod streaming; pub mod tree_node; -pub mod udaf; pub mod union; pub mod unnest; pub mod 
values; @@ -91,6 +90,11 @@ pub use datafusion_physical_expr::{ // Backwards compatibility pub use crate::stream::EmptyRecordBatchStream; pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub mod udaf { + pub use datafusion_physical_expr_common::aggregate::{ + create_aggregate_expr, AggregateFunctionExpr, + }; +} /// Represent nodes in the DataFusion Physical Plan. /// diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs deleted file mode 100644 index 74a5603c0c81..000000000000 --- a/datafusion/physical-plan/src/udaf.rs +++ /dev/null @@ -1,218 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module contains functions and structs supporting user-defined aggregate functions. 
- -use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::{Expr, GroupsAccumulator}; -use fmt::Debug; -use std::any::Any; -use std::fmt; - -use arrow::datatypes::{DataType, Field, Schema}; - -use super::{Accumulator, AggregateExpr}; -use datafusion_common::{not_impl_err, Result}; -pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; - -use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use std::sync::Arc; - -/// Creates a physical expression of the UDAF, that includes all necessary type coercion. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. -pub fn create_aggregate_expr( - fun: &AggregateUDF, - input_phy_exprs: &[Arc], - sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], - schema: &Schema, - name: impl Into, - ignore_nulls: bool, -) -> Result> { - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name: name.into(), - schema: schema.clone(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - })) -} - -/// Physical aggregate expression of a UDAF. 
-#[derive(Debug)] -pub struct AggregateFunctionExpr { - fun: AggregateUDF, - args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, - name: String, - schema: Schema, - // The logical order by expressions - sort_exprs: Vec, - // The physical order by expressions - ordering_req: LexOrdering, - ignore_nulls: bool, - ordering_fields: Vec, -} - -impl AggregateFunctionExpr { - /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` - pub fn fun(&self) -> &AggregateUDF { - &self.fun - } -} - -impl AggregateExpr for AggregateFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn expressions(&self) -> Vec> { - self.args.clone() - } - - fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - ) - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs::new( - &self.data_type, - &self.schema, - self.ignore_nulls, - &self.sort_exprs, - ); - - self.fun.accumulator(acc_args) - } - - fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.create_accumulator()?; - - // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to - // implement retract_batch method in order to run correctly - // currently in DataFusion. - // - // If this `retract_batches` is not present, there is no way - // to calculate result correctly. For example, the query - // - // ```sql - // SELECT - // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a - // FROM - // t - // ``` - // - // 1. First sum value will be the sum of rows between `[0, 1)`, - // - // 2. Second sum value will be the sum of rows between `[0, 2)` - // - // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. 
- // - // Since the accumulator keeps the running sum: - // - // 1. First sum we add to the state sum value between `[0, 1)` - // - // 2. Second sum we add to the state sum value between `[1, 2)` - // (`[0, 1)` is already in the state sum, hence running sum will - // cover `[0, 2)` range) - // - // 3. Third sum we add to the state sum value between `[2, 3)` - // (`[0, 2)` is already in the state sum). Also we need to - // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the apropriate range. - // - // When we use `UNBOUNDED PRECEDING` in the query starting - // index will always be 0 for the desired range, and hence the - // `retract_batch` method will not be called. In this case - // having retract_batch is not a requirement. - // - // This approach is a a bit different than window function - // approach. In window function (when they use a window frame) - // they get all the desired range during evaluation. - if !accumulator.supports_retract_batch() { - return not_impl_err!( - "Aggregate can not be used as a sliding accumulator because \ - `retract_batch` is not implemented: {}", - self.name - ); - } - Ok(accumulator) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - self.fun.groups_accumulator_supported() - } - - fn create_groups_accumulator(&self) -> Result> { - self.fun.create_groups_accumulator() - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } -} - -impl PartialEq for AggregateFunctionExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.fun == x.fun - && self.args.len() == x.args.len() - && self - .args - .iter() - .zip(x.args.iter()) - .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) - }) - .unwrap_or(false) - } -} diff --git 
a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f136e314559b..e680a1b2ff1e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -30,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -612,6 +613,7 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + first_value(lit(1)), ]; // ensure expressions created with the expr api can be round tripped