aggregate expr builder
Signed-off-by: jayzhan211 <[email protected]>
jayzhan211 committed Jul 23, 2024
1 parent deef834 commit 5d98c32
Showing 2 changed files with 196 additions and 90 deletions.
20 changes: 6 additions & 14 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -326,7 +326,7 @@ pub(crate) mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
use datafusion_physical_plan::aggregates::AggregateMode;

/// Mock data using a MemoryExec which has an exact count statistic
@@ -419,19 +419,11 @@ pub(crate) mod tests {

// Return the appropriate expr depending on whether COUNT is over a column or the table (*)
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
create_aggregate_expr(
&count_udaf(),
&[self.column()],
&[],
&[],
&[],
schema,
self.column_name(),
false,
false,
false,
)
.unwrap()
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
.schema(schema.clone())
.name(self.column_name())
.build()
.unwrap()
}
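
For reference, a minimal standalone sketch of the same builder call outside the test fixture; the single-column schema, the column name "a", the output name, and the exact import paths are assumptions for illustration and may differ between DataFusion versions:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr_common::aggregate::{AggregateExpr, AggregateExprBuilder};

// Build a COUNT aggregate expression over column "a".
fn build_count_expr() -> Result<Arc<dyn AggregateExpr>> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
        .schema(schema)
        .name("COUNT(a)")
        .build()
}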

/// what argument would this aggregate need in the plan?
266 changes: 190 additions & 76 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -23,7 +23,7 @@ pub mod tdigest;
pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, DFSchema, Result};
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::ReversedUDAF;
@@ -33,7 +33,7 @@ use datafusion_expr::{
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use self::utils::{down_cast_any_ref, ordering_fields};
use self::utils::down_cast_any_ref;
use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;
@@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
/// `is_reversed` indicates whether the aggregation runs in reverse order;
/// it can be used to hint the Accumulator to accumulate in reversed order.
/// Set it to false if the expression is not reversed.
///
/// You can also create the expression with [`AggregateExprBuilder`].
#[allow(clippy::too_many_arguments)]
pub fn create_aggregate_expr(
fun: &AggregateUDF,
@@ -66,45 +68,24 @@ pub fn create_aggregate_expr(
name: impl Into<String>,
ignore_nulls: bool,
is_distinct: bool,
is_reversed: bool,
_is_reversed: bool,
) -> Result<Arc<dyn AggregateExpr>> {
debug_assert_eq!(sort_exprs.len(), ordering_req.len());

let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
.collect::<Result<Vec<_>>>()?;

let ordering_fields = ordering_fields(ordering_req, &ordering_types);
let name = name.into();

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
logical_args: input_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name,
schema: schema.clone(),
dfschema: DFSchema::empty(),
sort_exprs: sort_exprs.to_vec(),
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
is_reversed,
}))
let mut builder =
AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec());
builder = builder.sort_exprs(sort_exprs.to_vec());
builder = builder.order_by(ordering_req.to_vec());
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.schema(schema.clone());
builder = builder.name(name);

if ignore_nulls {
builder = builder.ignore_nulls();
}
if is_distinct {
builder = builder.distinct();
}

builder.build()
}
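
A rough sketch of how the positional arguments of this wrapper map onto builder methods, reusing the imports from the sketch above plus `create_aggregate_expr` from this module; the DISTINCT flag and the names are chosen purely for illustration:

// Legacy positional form; note that the reversal flag is now ignored by this wrapper.
fn build_count_distinct(schema: &Schema) -> Result<Arc<dyn AggregateExpr>> {
    let _legacy = create_aggregate_expr(
        &count_udaf(),
        &[col("a", schema)?], // physical arguments
        &[],                  // logical arguments (slated for deprecation)
        &[],                  // logical sort expressions
        &[],                  // physical ordering requirement
        schema,
        "COUNT(DISTINCT a)",
        false, // ignore_nulls
        true,  // is_distinct
        false, // is_reversed (unused)
    )?;

    // Equivalent builder form.
    AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
        .schema(schema.clone())
        .name("COUNT(DISTINCT a)")
        .distinct()
        .build()
}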

#[allow(clippy::too_many_arguments)]
@@ -121,44 +102,177 @@ pub fn create_aggregate_expr_with_dfschema(
is_distinct: bool,
is_reversed: bool,
) -> Result<Arc<dyn AggregateExpr>> {
debug_assert_eq!(sort_exprs.len(), ordering_req.len());

let mut builder =
AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec());
builder = builder.sort_exprs(sort_exprs.to_vec());
builder = builder.order_by(ordering_req.to_vec());
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.dfschema(dfschema.clone());
let schema: Schema = dfschema.into();
builder = builder.schema(schema);
builder = builder.name(name);

if ignore_nulls {
builder = builder.ignore_nulls();
}
if is_distinct {
builder = builder.distinct();
}
if is_reversed {
builder = builder.reversed();
}

builder.build()
}

let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

let ordering_fields = ordering_fields(ordering_req, &ordering_types);

Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
logical_args: input_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
schema: schema.clone(),
dfschema: dfschema.clone(),
sort_exprs: sort_exprs.to_vec(),
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
is_reversed,
}))
#[derive(Debug, Clone)]
pub struct AggregateExprBuilder {
fun: Arc<AggregateUDF>,
/// Physical expressions of the aggregate function
args: Vec<Arc<dyn PhysicalExpr>>,
/// Logical expressions of the aggregate function; to be deprecated in <https://github.com/apache/datafusion/issues/11359>
logical_args: Vec<Expr>,
name: String,
/// Arrow Schema for the aggregate function
schema: Schema,
/// DataFusion schema for the aggregate function
dfschema: DFSchema,
/// The logical ORDER BY expressions; to be deprecated in <https://github.com/apache/datafusion/issues/11359>
sort_exprs: Vec<Expr>,
/// The physical order by expressions
ordering_req: LexOrdering,
/// Whether to ignore null values
ignore_nulls: bool,
/// Whether this is a distinct aggregate function
is_distinct: bool,
/// Whether the expression is reversed
is_reversed: bool,
}

impl AggregateExprBuilder {
pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> Self {
Self {
fun,
args,
logical_args: vec![],
name: String::new(),
schema: Schema::empty(),
dfschema: DFSchema::empty(),
sort_exprs: vec![],
ordering_req: vec![],
ignore_nulls: false,
is_distinct: false,
is_reversed: false,
}
}

pub fn build(self) -> Result<Arc<dyn AggregateExpr>> {
let Self {
fun,
args,
logical_args,
name,
schema,
dfschema,
sort_exprs,
ordering_req,
ignore_nulls,
is_distinct,
is_reversed,
} = self;
if args.is_empty() {
return internal_err!("args should not be empty");
}

let mut ordering_fields = vec![];

debug_assert_eq!(sort_exprs.len(), ordering_req.len());
if !ordering_req.is_empty() {
let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types);
}

let input_exprs_types = args
.iter()
.map(|arg| arg.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let data_type = fun.return_type(&input_exprs_types)?;

Ok(Arc::new(AggregateFunctionExpr {
fun: Arc::unwrap_or_clone(fun),
args,
logical_args,
data_type,
name,
schema,
dfschema,
sort_exprs,
ordering_req,
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
is_reversed,
}))
}

pub fn name(mut self, name: impl Into<String>) -> Self {
self.name = name.into();
self
}

pub fn schema(mut self, schema: Schema) -> Self {
self.schema = schema;
self
}

pub fn dfschema(mut self, dfschema: DFSchema) -> Self {
self.dfschema = dfschema;
self
}

pub fn order_by(mut self, order_by: LexOrdering) -> Self {
self.ordering_req = order_by;
self
}

pub fn reversed(mut self) -> Self {
self.is_reversed = true;
self
}

pub fn distinct(mut self) -> Self {
self.is_distinct = true;
self
}

pub fn ignore_nulls(mut self) -> Self {
self.ignore_nulls = true;
self
}

/// This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
pub fn sort_exprs(mut self, sort_exprs: Vec<Expr>) -> Self {
self.sort_exprs = sort_exprs;
self
}

/// This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
pub fn logical_exprs(mut self, logical_args: Vec<Expr>) -> Self {
self.logical_args = logical_args;
self
}
}
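
Finally, a hedged sketch of chaining several builder options and of the new early error on an empty argument list (imports as in the first sketch; the flags and names here only exercise the fluent API):

fn builder_examples(schema: &Schema) -> Result<()> {
    // Chain several options on a single builder.
    let _agg = AggregateExprBuilder::new(count_udaf(), vec![col("a", schema)?])
        .schema(schema.clone())
        .name("count_distinct_a")
        .distinct()
        .ignore_nulls()
        .build()?;

    // With no arguments, build() now returns an error up front instead of
    // indexing into an empty list of input types.
    let empty = AggregateExprBuilder::new(count_udaf(), vec![])
        .schema(schema.clone())
        .name("count_empty")
        .build();
    assert!(empty.is_err());

    Ok(())
}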

/// An aggregate expression that:
