Skip to content

Commit

Permalink
fix udaf macro for distinct but not apply
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed May 16, 2024
1 parent ea81c6e commit d55abb4
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 44 deletions.
3 changes: 1 addition & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null, sum};

use async_trait::async_trait;
use datafusion_functions_aggregate::count::count;

/// Contains options that control how data is
/// written out from a DataFrame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ impl PhysicalOptimizerRule for AggregateStatistics {
.expect("take_optimizable() ensures that this is a AggregateExec");
let stats = partial_agg_exec.input().statistics()?;
let mut projections = vec![];

for expr in partial_agg_exec.aggr_expr() {
if let Some((non_null_rows, name)) =
take_optimizable_column_and_table_count(&**expr, &stats)
Expand Down
1 change: 0 additions & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1894,7 +1894,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);

let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1869,6 +1869,7 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> Result<()> {
null_treatment,
}) => {
write_function_name(w, &fun.to_string(), false, args)?;

if let Some(nt) = null_treatment {
w.write_str(" ")?;
write!(w, "{}", nt)?;
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ path = "src/lib.rs"

[dependencies]
arrow = { workspace = true }
concat-idents = "1.1.5"
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
Expand Down
20 changes: 3 additions & 17 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ use arrow::{
use datafusion_common::{
downcast_value, internal_err, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::{GroupsAccumulatorSupportedArgs, StateFieldsArgs};
use datafusion_expr::Expr;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
Expand All @@ -55,26 +53,14 @@ use datafusion_physical_expr_common::{
binary_map::OutputType,
};

make_udaf_expr_and_func!(
make_distinct_udaf_expr_and_func!(
Count,
count,
expression,
"Returns the number of non-null values in the group.",
count_udaf
);

/// Create an expression to represent the count(distinct) aggregate function
pub fn count_distinct(expression: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new_udf(
count_udaf(),
vec![expression],
true,
None,
None,
None,
))
}

pub struct Count {
signature: Signature,
aliases: Vec<String>,
Expand Down Expand Up @@ -252,7 +238,7 @@ impl AggregateUDFImpl for Count {
&self.aliases
}

fn groups_accumulator_supported(&self, args: GroupsAccumulatorSupportedArgs) -> bool {
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
// groups accumulator only supports `COUNT(c1)`, not
// `COUNT(c1, c2)`, etc
if args.is_distinct {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ use std::sync::Arc;
/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::count::count;
pub use super::count::count_distinct;
pub use super::covariance::covar_samp;
pub use super::covariance::covar_pop;
pub use super::first_last::first_value;
}

Expand Down
55 changes: 38 additions & 17 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,43 +50,64 @@ macro_rules! make_udaf_expr_and_func {
}
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $distinct:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
$($arg: datafusion_expr::Expr,)*
args: Vec<datafusion_expr::Expr>,
distinct: bool,
filter: Option<Box<datafusion_expr::Expr>>,
order_by: Option<Vec<datafusion_expr::Expr>>,
null_treatment: Option<sqlparser::ast::NullTreatment>
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
args,
distinct,
None,
None,
None
filter,
order_by,
null_treatment,
))
}
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
}

macro_rules! make_distinct_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
args: Vec<datafusion_expr::Expr>,
distinct: bool,
filter: Option<Box<datafusion_expr::Expr>>,
order_by: Option<Vec<datafusion_expr::Expr>>,
null_treatment: Option<sqlparser::ast::NullTreatment>
$($arg: datafusion_expr::Expr,)*
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
args,
distinct,
filter,
order_by,
null_treatment,
vec![$($arg),*],
false,
None,
None,
None
))
}

// Build the distinct version of the function.
// Its name is the original function name with `_distinct` appended.
concat_idents::concat_idents!(distinct_fn_name = $EXPR_FN, _distinct {
#[doc = $DOC]
pub fn distinct_fn_name(
$($arg: datafusion_expr::Expr,)*
) -> datafusion_expr::Expr {
datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
$AGGREGATE_UDF_FN(),
vec![$($arg),*],
true,
None,
None,
None
))
}
});

create_func!($UDAF, $AGGREGATE_UDF_FN);
};
}
Expand Down
3 changes: 1 addition & 2 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,9 @@ mod tests {
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr::count;
use datafusion_expr::expr::Sort;
use datafusion_expr::{
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, WindowFrame,
WindowFrameBound, WindowFrameUnits,
};
Expand Down
2 changes: 0 additions & 2 deletions datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ impl OptimizerRule for PushDownLimit {
plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
use std::cmp::min;

let LogicalPlan::Limit(limit) = plan else {
return Ok(None);
};
Expand Down
5 changes: 3 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::execution::FunctionRegistry;
use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp};
use datafusion::functions_aggregate::expr_fn::first_value;
use datafusion::functions_aggregate::expr_fn::{first_value, covar_pop, covar_samp, count};
use datafusion::prelude::*;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::config::{FormatOptions, TableOptions};
Expand Down Expand Up @@ -625,6 +624,8 @@ async fn roundtrip_expr_api() -> Result<()> {
covar_samp(lit(1.5), lit(2.2)),
covar_pop(lit(1.5), lit(2.2)),
count(lit(1)),
// TODO: the distinct variant is missing from this round-trip list
// count_distinct(lit(2)),
];

// ensure expressions created with the expr api can be round tripped
Expand Down

0 comments on commit d55abb4

Please sign in to comment.