Skip to content

Commit

Permalink
rm function
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jun 13, 2024
1 parent 2174200 commit 0579e92
Show file tree
Hide file tree
Showing 21 changed files with 82 additions and 177 deletions.
6 changes: 0 additions & 6 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Count
Count,
/// Minimum
Min,
/// Maximum
Expand Down Expand Up @@ -89,7 +87,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Count => "COUNT",
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Expand Down Expand Up @@ -135,7 +132,6 @@ impl FromStr for AggregateFunction {
"bit_xor" => AggregateFunction::BitXor,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
Expand Down Expand Up @@ -190,7 +186,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Count => Ok(DataType::Int64),
AggregateFunction::Max | AggregateFunction::Min => {
// For min and max agg function, the returned type is same as input type.
// The coerced_data_types is same with input_types.
Expand Down Expand Up @@ -249,7 +244,6 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
Expand Down
13 changes: 0 additions & 13 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2135,18 +2135,6 @@ mod test {

use super::*;

#[test]
fn test_count_return_type() -> Result<()> {
let fun = find_df_window_func("count").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
assert_eq!(DataType::Int64, observed);

Ok(())
}

#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
Expand Down Expand Up @@ -2250,7 +2238,6 @@ mod test {
"nth_value",
"min",
"max",
"count",
"avg",
];
for name in names {
Expand Down
2 changes: 0 additions & 2 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count => Ok(input_types.to_vec()),
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
Expand Down Expand Up @@ -525,7 +524,6 @@ mod tests {
// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::Count,
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
Expand Down
37 changes: 9 additions & 28 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{
aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
};
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};

/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
Expand Down Expand Up @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool {
}

fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
match aggregate_function {
matches!(aggregate_function,
AggregateFunction {
func_def: AggregateFunctionDefinition::UDF(udf),
args,
..
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true,
AggregateFunction {
func_def:
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::Count,
),
args,
..
} if args.len() == 1 && is_wildcard(&args[0]) => true,
_ => false,
}
} if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
let args = &window_function.args;
match window_function.fun {
WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Count,
) if args.len() == 1 && is_wildcard(&args[0]) => true,
matches!(window_function.fun,
WindowFunctionDefinition::AggregateUDF(ref udaf)
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) =>
{
true
}
_ => false,
}
if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]))
}

fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
Expand Down Expand Up @@ -123,9 +103,10 @@ mod tests {
use datafusion_expr::expr::Sort;
use datafusion_expr::{
col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max,
out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame,
WindowFrameBound, WindowFrameUnits,
out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound,
WindowFrameUnits,
};
use datafusion_functions_aggregate::count::count_udaf;
use std::sync::Arc;

use datafusion_functions_aggregate::expr_fn::{count, sum};
Expand Down Expand Up @@ -240,7 +221,7 @@ mod tests {

let plan = LogicalPlanBuilder::from(table_scan)
.window(vec![Expr::WindowFunction(expr::WindowFunction::new(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count),
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![],
vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))],
Expand Down
10 changes: 2 additions & 8 deletions datafusion/optimizer/src/decorrelate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch(
Expr::AggregateFunction(expr::AggregateFunction {
func_def, ..
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
if matches!(fun, datafusion_expr::AggregateFunction::Count) {
Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(
0,
))))
} else {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::BuiltIn(_fun) => {
Transformed::yes(Expr::Literal(ScalarValue::Null))
}
AggregateFunctionDefinition::UDF(fun) => {
if fun.name() == "COUNT" {
Expand Down
16 changes: 6 additions & 10 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ mod tests {
use datafusion_common::{
Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
};
use datafusion_expr::AggregateExt;
use datafusion_expr::{
binary_expr, build_join_schema,
builder::table_scan_with_filters,
Expand All @@ -830,6 +831,7 @@ mod tests {
WindowFunctionDefinition,
};

use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::count;

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -1888,16 +1890,10 @@ mod tests {
#[test]
fn aggregate_filter_pushdown() -> Result<()> {
let table_scan = test_table_scan()?;

let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("b")],
false,
Some(Box::new(col("c").gt(lit(42)))),
None,
None,
));

let aggr_with_filter = count_udaf()
.call(vec![col("b")])
.filter(col("c").gt(lit(42)))
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
Expand Down
46 changes: 20 additions & 26 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,11 @@ mod tests {
use super::*;
use crate::test::*;
use datafusion_expr::expr::{self, GroupingSet};
use datafusion_expr::AggregateExt;
use datafusion_expr::{
lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum};
use datafusion_functions_aggregate::sum::sum_udaf;

Expand Down Expand Up @@ -679,14 +681,11 @@ mod tests {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a) FILTER (WHERE a > 5)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
Some(Box::new(col("a").gt(lit(5)))),
None,
None,
));
let expr = count_udaf()
.call(vec![col("a")])
.distinct()
.filter(col("a").gt(lit(5)))
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
Expand Down Expand Up @@ -725,19 +724,16 @@ mod tests {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a ORDER BY a)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
None,
Some(vec![col("a")]),
None,
));
let expr = count_udaf()
.call(vec![col("a")])
.distinct()
.order_by(vec![col("a").sort(true, false)])
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand All @@ -748,19 +744,17 @@ mod tests {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
Some(Box::new(col("a").gt(lit(5)))),
Some(vec![col("a")]),
None,
));
let expr = count_udaf()
.call(vec![col("a")])
.distinct()
.filter(col("a").gt(lit(5)))
.order_by(vec![col("a").sort(true, false)])
.build()?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(plan, expected)
Expand Down
6 changes: 4 additions & 2 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::{plan_err, Result};
use datafusion_expr::test::function_stub::sum_udaf;
use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
Expand Down Expand Up @@ -323,7 +324,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
let statement = &ast[0];
let context_provider = MyContextProvider::default().with_udaf(sum_udaf());
let context_provider = MyContextProvider::default().with_udaf(sum_udaf()).with_udaf(count_udaf());
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand All @@ -345,7 +346,8 @@ struct MyContextProvider {

impl MyContextProvider {
fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
self.udafs.insert(udaf.name().to_string(), udaf);
// TODO: change to to_string() if all the function name is converted to lowercase
self.udafs.insert(udaf.name().to_lowercase(), udaf);
self
}
}
Expand Down
19 changes: 1 addition & 18 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::sync::Arc;

use arrow::datatypes::Schema;

use datafusion_common::{exec_err, internal_err, not_impl_err, Result};
use datafusion_common::{exec_err, not_impl_err, Result};
use datafusion_expr::AggregateFunction;

use crate::aggregate::average::Avg;
Expand Down Expand Up @@ -61,9 +61,6 @@ pub fn create_aggregate_expr(
.collect::<Result<Vec<_>>>()?;
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
(AggregateFunction::Count, _) => {
return internal_err!("Builtin Count will be removed");
}
(AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
input_phy_exprs[0].clone(),
name,
Expand Down Expand Up @@ -642,20 +639,6 @@ mod tests {
Ok(())
}

#[test]
fn test_count_return_type() -> Result<()> {
let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?;
assert_eq!(DataType::Int64, observed);

let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?;
assert_eq!(DataType::Int64, observed);

let observed =
AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?;
assert_eq!(DataType::Int64, observed);
Ok(())
}

#[test]
fn test_avg_return_type() -> Result<()> {
let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?;
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true }

[dev-dependencies]
datafusion-functions = { workspace = true, default-features = true }
datafusion-functions-aggregate = { workspace = true }
doc-comment = { workspace = true }
strum = { version = "0.26.1", features = ["derive"] }
tokio = { workspace = true, features = ["rt-multi-thread"] }
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ enum AggregateFunction {
MAX = 1;
// SUM = 2;
AVG = 3;
COUNT = 4;
// COUNT = 4;
// APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
// VARIANCE = 7;
Expand Down
Loading

0 comments on commit 0579e92

Please sign in to comment.