Skip to content

Commit

Permalink
backup
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed May 5, 2024
1 parent 745b1c6 commit 7d47526
Show file tree
Hide file tree
Showing 9 changed files with 634 additions and 34 deletions.
40 changes: 39 additions & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

Expand All @@ -38,6 +38,39 @@ pub type ScalarFunctionImplementation =
pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;

/// [`StateFieldsArgs`] contains information about the state fields that an
/// aggregate function requires. See [AggregateFunctionExpr] for more information about
/// how these fields are used.
pub struct StateFieldsArgs<'a> {
/// Name supplied by the caller.
/// NOTE(review): presumably the name of the aggregate expression — confirm.
pub name: &'a str,
/// Data type supplied by the caller.
/// NOTE(review): presumably the type of the aggregated value — confirm.
pub value_type: DataType,
/// Fields describing the ordering.
/// NOTE(review): presumably derived from the `ORDER BY` expressions — confirm.
pub ordering_fields: Vec<Field>,
/// Nullability flag.
/// NOTE(review): which field(s) this applies to is not visible here — confirm.
pub nullable: bool,
}

impl<'a> StateFieldsArgs<'a> {
pub fn new(
name: &'a str,
value_type: DataType,
ordering_fields: Vec<Field>,
nullable: bool,
) -> Self {
Self {
name,
value_type,
ordering_fields,
nullable,
}
}
}

/// [`GroupsAccumulatorArgs`] contains the information handed to
/// `create_groups_accumulator` when a `GroupsAccumulator` is requested.
pub struct GroupsAccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,
/// The name of the aggregate expression
pub name: &'a str,
}

/// [`AccumulatorArgs`] contains information about how an aggregate
/// function was called, including the types of its arguments and any optional
/// ordering expressions.
Expand Down Expand Up @@ -66,6 +99,9 @@ pub struct AccumulatorArgs<'a> {
///
/// If no `ORDER BY` is specified, `sort_exprs` will be empty.
pub sort_exprs: &'a [Expr],

/// The name of the aggregate expression
pub name: &'a str,
}

impl<'a> AccumulatorArgs<'a> {
Expand All @@ -74,12 +110,14 @@ impl<'a> AccumulatorArgs<'a> {
schema: &'a Schema,
ignore_nulls: bool,
sort_exprs: &'a [Expr],
name: &'a str,
) -> Self {
Self {
data_type,
schema,
ignore_nulls,
sort_exprs,
name,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub use signature::{
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
47 changes: 43 additions & 4 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::function::AccumulatorArgs;
use crate::function::{AccumulatorArgs, GroupsAccumulatorArgs};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
Expand Down Expand Up @@ -179,11 +179,13 @@ impl AggregateUDF {
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
// TODO(review): the three parameters below look like they are meant to be
// replaced by a single `StateFieldsArgs` argument — confirm before changing.
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
) -> Result<Vec<Field>> {
// Forwards the arguments unchanged to the wrapped implementation.
self.inner.state_fields(name, value_type, ordering_fields)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
Expand All @@ -192,8 +194,22 @@ impl AggregateUDF {
}

/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
pub fn create_groups_accumulator(
&self,
args: GroupsAccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator(args)
}

/// See [`AggregateUDFImpl::create_sliding_accumulator`] for more details.
pub fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
self.inner.create_sliding_accumulator(args)
}

/// See [`AggregateUDFImpl::reverse_expr`] for more details.
pub fn reverse_expr(&self) -> ReversedUDAF {
self.inner.reverse_expr()
}
}

Expand Down Expand Up @@ -343,7 +359,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
///
/// For maximum performance, a [`GroupsAccumulator`] should be
/// implemented in addition to [`Accumulator`].
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
fn create_groups_accumulator(
&self,
_args: GroupsAccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet")
}

Expand All @@ -354,6 +373,26 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}

/// Returns an [`Accumulator`] for the "sliding" case.
///
/// The default implementation forwards `args` unchanged to
/// [`Self::accumulator`], so implementors only need to override this when a
/// different accumulator is required.
///
/// NOTE(review): presumably used for sliding window frames — confirm
/// against the callers.
fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
self.accumulator(args)
}

/// Describes the reversed form of this aggregate as a [`ReversedUDAF`].
///
/// The default implementation returns [`ReversedUDAF::NotSupported`].
fn reverse_expr(&self) -> ReversedUDAF {
ReversedUDAF::NotSupported
}
}

/// Value returned by `reverse_expr`, describing how an aggregate relates to
/// its reversed form.
pub enum ReversedUDAF {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
NotSupported,
/// The expression is different from the original expression.
/// NOTE(review): the wrapped implementation is presumably the one to use
/// for the reversed expression — confirm.
Reversed(Arc<dyn AggregateUDFImpl>),
}

/// AggregateUDF that adds an alias to the underlying function. It is better to
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
pub mod macros;

pub mod first_last;
pub mod sum;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand Down
68 changes: 47 additions & 21 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,59 @@
// specific language governing permissions and limitations
// under the License.

macro_rules! make_udaf_function {
macro_rules! make_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
$($arg: datafusion_expr::Expr,)*
distinct: bool,
filter: Option<Box<datafusion_expr::Expr>>,
order_by: Option<Vec<datafusion_expr::Expr>>,
null_treatment: Option<sqlparser::ast::NullTreatment>
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
distinct,
filter,
order_by,
null_treatment,
))
}
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
paste::paste! {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
args: Vec<Expr>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
null_treatment: Option<NullTreatment>
) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
args,
distinct,
filter,
order_by,
null_treatment,
))
}
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
args: Vec<datafusion_expr::Expr>,
distinct: bool,
filter: Option<Box<datafusion_expr::Expr>>,
order_by: Option<Vec<datafusion_expr::Expr>>,
null_treatment: Option<sqlparser::ast::NullTreatment>
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
args,
distinct,
filter,
order_by,
null_treatment,
))
}
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
}

macro_rules! create_func {
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
paste::paste! {
/// Singleton instance of [$UDAF], ensures the UDAF is only created once
/// named STATIC_$(UDAF). For example `STATIC_FirstValue`
#[allow(non_upper_case_globals)]
static [< STATIC_ $UDAF >]: std::sync::OnceLock<std::sync::Arc<datafusion_expr::AggregateUDF>> =
std::sync::OnceLock::new();
std::sync::OnceLock::new();

/// AggregateFunction that returns a [AggregateUDF] for [$UDAF]
///
Expand Down
Loading

0 comments on commit 7d47526

Please sign in to comment.