Skip to content

Commit

Permalink
fix with args
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed May 13, 2024
1 parent 9917b37 commit c6ba32b
Show file tree
Hide file tree
Showing 15 changed files with 112 additions and 121 deletions.
3 changes: 3 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 8 additions & 10 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
GroupsAccumulator, Signature,
function::{AccumulatorArgs, GroupsAccumulatorSupportedArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the full AggregateUDFImpl API to implement a user
Expand Down Expand Up @@ -92,21 +92,19 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. The accumulator's `state()` must match the types here.
fn state_fields(
&self,
_name: &str,
value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
Field::new("prod", value_type, true),
Field::new("prod", args.return_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
true
}

Expand Down
16 changes: 8 additions & 8 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::function::{
AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs, StateFieldsArgs,
};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};
Expand Down Expand Up @@ -70,16 +72,14 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked")
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
true
}

Expand Down
10 changes: 7 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
create_udaf,
function::{AccumulatorArgs, GroupsAccumulatorSupportedArgs},
AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;

Expand Down Expand Up @@ -725,7 +726,10 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
panic!("accumulator shouldn't invoke");
}

fn groups_accumulator_supported(&self) -> bool {
fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
true
}

Expand Down
10 changes: 2 additions & 8 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::expr::{
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery,
Expand Down Expand Up @@ -692,14 +693,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
_is_distinct: bool,
_input_type: DataType,
) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
Expand Down
17 changes: 16 additions & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Schema};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;

Expand Down Expand Up @@ -91,6 +91,21 @@ impl<'a> AccumulatorArgs<'a> {
}
}

/// [`GroupsAccumulatorSupportedArgs`] contains information to determine if an
/// aggregate function supports the groups accumulator.
///
/// It is passed by value to `groups_accumulator_supported`, so an
/// implementation can decide per call site (for example, `COUNT` only reports
/// support for a single, non-distinct argument).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GroupsAccumulatorSupportedArgs {
    /// Number of arguments the aggregate is invoked with, e.g. `1` for
    /// `COUNT(c1)` and `2` for `COUNT(c1, c2)`.
    pub args_num: usize,
    /// Whether the aggregate is invoked as a distinct aggregate
    /// (e.g. `COUNT(DISTINCT c1)`).
    pub is_distinct: bool,
}

/// [`StateFieldsArgs`] contains the information an aggregate function can use
/// to describe the fields of its intermediate state.
///
/// It is the single argument of `state_fields`; every reference borrows from
/// the caller for the lifetime `'a`, so building it does not clone any types
/// or fields.
#[derive(Debug, Clone)]
pub struct StateFieldsArgs<'a> {
    /// Name of the aggregate expression. State field names must be unique
    /// within the query and should therefore be derived from this name
    /// (see `format_state_name`).
    pub name: &'a str,
    /// Data type of the aggregate's input (used, for example, as the item
    /// type of the list state kept by a distinct count).
    pub input_type: &'a DataType,
    /// Data type returned by the aggregate; the default `state_fields`
    /// implementation uses it as the type of the single `value` state field.
    pub return_type: &'a DataType,
    /// Extra fields appended to the state by order-sensitive aggregates
    /// (presumably one per ordering expression; confirm against the caller).
    pub ordering_fields: &'a [Field],
    /// Whether the aggregate is invoked as a distinct aggregate.
    pub is_distinct: bool,
}

/// Factory that returns an accumulator for the given aggregate function.
pub type AccumulatorFactoryFunction =
Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
Expand Down
53 changes: 20 additions & 33 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs,
StateFieldsArgs,
};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
Expand Down Expand Up @@ -177,31 +180,16 @@ impl AggregateUDF {
/// for more details.
///
/// This is used to support multi-phase aggregations
pub fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
is_distinct: bool,
input_type: DataType,
) -> Result<Vec<Field>> {
self.inner.state_fields(
name,
value_type,
ordering_fields,
is_distinct,
input_type,
)
pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
self.inner.state_fields(args)
}

/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
pub fn groups_accumulator_supported(
&self,
args_num: usize,
is_distinct: bool,
args: GroupsAccumulatorSupportedArgs,
) -> bool {
self.inner
.groups_accumulator_supported(args_num, is_distinct)
self.inner.groups_accumulator_supported(args)
}

/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
Expand Down Expand Up @@ -338,21 +326,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The name of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
_is_distinct: bool,
_input_type: DataType,
) -> Result<Vec<Field>> {
let value_fields = vec![Field::new(
format_state_name(name, "value"),
value_type,
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let fields = vec![Field::new(
format_state_name(args.name, "value"),
args.return_type.clone(),
true,
)];

Ok(value_fields.into_iter().chain(ordering_fields).collect())
Ok(fields
.into_iter()
.chain(args.ordering_fields.to_vec())
.collect())
}

/// If the aggregate expression has a specialized
Expand All @@ -365,7 +349,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// `Self::accumulator` for certain queries, such as when this aggregate is
/// used as a window function or when there are no GROUP BY columns in the
/// query.
fn groups_accumulator_supported(&self, _args_num: usize, _is_distinct: bool) -> bool {
fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
false
}

Expand Down
24 changes: 9 additions & 15 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use datafusion_common::{
downcast_value, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::{GroupsAccumulatorSupportedArgs, StateFieldsArgs};
use datafusion_expr::Expr;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Expand Down Expand Up @@ -120,23 +121,16 @@ impl AggregateUDFImpl for Count {
Ok(DataType::Int64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
is_distinct: bool,
input_type: DataType,
) -> Result<Vec<Field>> {
if is_distinct {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
if args.is_distinct {
Ok(vec![Field::new_list(
format_state_name(name, "count distinct"),
Field::new("item", input_type, true),
format_state_name(args.name, "count distinct"),
Field::new("item", args.input_type.clone(), true),
false,
)])
} else {
Ok(vec![Field::new(
format_state_name(name, "count"),
format_state_name(args.name, "count"),
DataType::Int64,
true,
)])
Expand Down Expand Up @@ -258,13 +252,13 @@ impl AggregateUDFImpl for Count {
&self.aliases
}

fn groups_accumulator_supported(&self, args_num: usize, is_distinct: bool) -> bool {
fn groups_accumulator_supported(&self, args: GroupsAccumulatorSupportedArgs) -> bool {
// groups accumulator only supports `COUNT(c1)`, not
// `COUNT(c1, c2)`, etc
if is_distinct {
if args.is_distinct {
return false;
}
args_num == 1
args.args_num == 1
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
Expand Down
26 changes: 8 additions & 18 deletions datafusion/functions-aggregate/src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::{
function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;

Expand Down Expand Up @@ -101,14 +103,8 @@ impl AggregateUDFImpl for CovarianceSample {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
_is_distinct: bool,
_input_type: DataType,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Expand Down Expand Up @@ -178,14 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation {
Ok(DataType::Float64)
}

fn state_fields(
&self,
name: &str,
_value_type: DataType,
_ordering_fields: Vec<Field>,
_is_distinct: bool,
_input_type: DataType,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
Expand Down
17 changes: 5 additions & 12 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Expand Down Expand Up @@ -147,20 +147,13 @@ impl AggregateUDFImpl for FirstValue {
.map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
}

fn state_fields(
&self,
name: &str,
value_type: DataType,
ordering_fields: Vec<Field>,
_is_distinct: bool,
_input_type: DataType,
) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new(
format_state_name(name, "first_value"),
value_type,
format_state_name(args.name, "first_value"),
args.return_type.clone(),
true,
)];
fields.extend(ordering_fields);
fields.extend(args.ordering_fields.to_vec());
fields.push(Field::new("is_set", DataType::Boolean, true));
Ok(fields)
}
Expand Down
Loading

0 comments on commit c6ba32b

Please sign in to comment.