
Commit

Make ArrayAgg (not ordered or distinct) into a UDAF
eejbyfeldt committed Jun 21, 2024
1 parent 18042fd commit 1b9f54f
Showing 14 changed files with 331 additions and 243 deletions.
6 changes: 3 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
@@ -1590,10 +1590,10 @@ mod tests {
 use datafusion_common::{Constraint, Constraints};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_expr::{
-    array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
-    ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
+    cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
+    Volatility, WindowFrame, WindowFunctionDefinition,
 };
-use datafusion_functions_aggregate::expr_fn::count_distinct;
+use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};

28 changes: 28 additions & 0 deletions datafusion/core/src/physical_planner.rs
@@ -92,6 +92,7 @@ use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan,
WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
@@ -1854,6 +1855,33 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::AggregateFunction::ArrayAgg,
) if !distinct && order_by.is_none() => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
&array_agg_udaf(),
&physical_args,
args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
ignore_nulls,
*distinct,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::BuiltIn(fun) => {
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
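
The new match arm only covers plain ARRAY_AGG calls: the guard !distinct && order_by.is_none() routes them to array_agg_udaf(), while DISTINCT and ORDER BY variants keep using the built-in expression (the UDAF's simplify hook added below rewrites such calls back to the built-in). A minimal sketch of which planner path each query shape takes; the table, columns, and values are made up for illustration:

use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch only: the table name, columns, and data are hypothetical.
async fn array_agg_paths() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.sql("CREATE TABLE t AS VALUES (1, 10), (2, 10), (3, 20)").await?;

    // Planned through the new ArrayAgg UDAF (no DISTINCT, no ORDER BY):
    ctx.sql("SELECT array_agg(column1) FROM t GROUP BY column2").await?.show().await?;

    // Still planned through the built-in ArrayAgg expressions:
    ctx.sql("SELECT array_agg(DISTINCT column1) FROM t GROUP BY column2").await?.show().await?;
    ctx.sql("SELECT array_agg(column1 ORDER BY column1) FROM t GROUP BY column2").await?.show().await?;
    Ok(())
}
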
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
@@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_expr::expr::{GroupingSet, Sort};
 use datafusion_expr::var_provider::{VarProvider, VarType};
 use datafusion_expr::{
-    array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
-    placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
-    WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
+    avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
+    scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
+    WindowFrameUnits, WindowFunctionDefinition,
 };
-use datafusion_functions_aggregate::expr_fn::{count, sum};
+use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
@@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
         Schema::new(vec![Field::new_list(
             "ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
             Field::new("item", DataType::UInt32, true),
-            false
+            true
         ),])
     );

12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
@@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the array_agg() aggregate function
pub fn array_agg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::ArrayAgg,
vec![expr],
false,
None,
None,
None,
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
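
Code that previously imported array_agg from datafusion_expr now imports it from datafusion_functions_aggregate::expr_fn, as the updated tests above show. A minimal DataFrame usage sketch, assuming a SessionContext that already has a hypothetical table t with columns c1 and c2:

use datafusion::error::Result;
use datafusion::prelude::{col, SessionContext};
use datafusion_functions_aggregate::expr_fn::array_agg;

// Sketch only: t, c1, and c2 are assumed to exist in the context.
async fn collect_lists(ctx: &SessionContext) -> Result<()> {
    let df = ctx
        .table("t")
        .await?
        // One list of c1 values per distinct value of c2.
        .aggregate(vec![col("c2")], vec![array_agg(col("c1"))])?;
    df.show().await?;
    Ok(())
}
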
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
@@ -40,6 +40,7 @@ path = "src/lib.rs"
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
271 changes: 271 additions & 0 deletions datafusion/functions-aggregate/src/array_agg.rs
@@ -0,0 +1,271 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expressions that can be evaluated at runtime during query execution
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use arrow_array::Array;
use arrow_schema::Field;

use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr::AggregateFunctionDefinition;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::AggregateUDFImpl;
use datafusion_expr::Expr;
use datafusion_expr::{Accumulator, Signature, Volatility};
use std::sync::Arc;

make_udaf_expr_and_func!(
ArrayAgg,
array_agg,
expression,
"Computes the nth value",
array_agg_udaf
);

#[derive(Debug)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
signature: Signature,
alias: Vec<String>,
}

impl Default for ArrayAgg {
fn default() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
alias: vec!["array_agg".to_string()],
}
}
}

impl AggregateUDFImpl for ArrayAgg {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"ARRAY_AGG"
}

fn aliases(&self) -> &[String] {
&self.alias
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::List(Arc::new(Field::new(
"item",
arg_types[0].clone(),
true,
))))
}

fn state_fields(
&self,
args: datafusion_expr::function::StateFieldsArgs,
) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(args.name, "array_agg"),
Field::new("item", args.input_type.clone(), true),
true,
)])
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?))
}

fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
datafusion_expr::ReversedUDAF::Identical
}

fn simplify(
&self,
) -> Option<datafusion_expr::function::AggregateFunctionSimplification> {
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
if aggregate_function.order_by.is_some() || aggregate_function.distinct {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
} else {
Ok(Expr::AggregateFunction(aggregate_function))
}
};

Some(Box::new(simplify))
}
}

#[derive(Debug)]
pub struct ArrayAggAccumulator {
values: Vec<ArrayRef>,
datatype: DataType,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given item data type
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
values: vec![],
datatype: datatype.clone(),
})
}
}

impl Accumulator for ArrayAggAccumulator {
// Append value like Int64Array(1,2,3)
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
assert!(values.len() == 1, "array_agg can only take 1 param!");
let val = values[0].clone();
self.values.push(val);
Ok(())
}

// Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert!(states.len() == 1, "array_agg states must be singleton!");

let list_arr = as_list_array(&states[0])?;
for arr in list_arr.iter().flatten() {
self.values.push(arr);
}
Ok(())
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&mut self) -> Result<ScalarValue> {
// Concatenate the buffered element arrays into a single ListArray

let element_arrays: Vec<&dyn Array> =
self.values.iter().map(|a| a.as_ref()).collect();

if element_arrays.is_empty() {
let arr = ScalarValue::new_list(&[], &self.datatype);
return Ok(ScalarValue::List(arr));
}

let concated_array = arrow::compute::concat(&element_arrays)?;
let list_array = array_into_list_array(concated_array);

Ok(ScalarValue::List(Arc::new(list_array)))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<ArrayRef>() * self.values.capacity())
+ self
.values
.iter()
.map(|arr| arr.get_array_memory_size())
.sum::<usize>()
+ self.datatype.size()
- std::mem::size_of_val(&self.datatype)
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use datafusion_physical_expr_common::expressions::column::Column;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

#[test]
fn test_array_agg_expr() -> Result<()> {
let data_types = vec![
DataType::UInt32,
DataType::Int32,
DataType::Float32,
DataType::Float64,
DataType::Decimal128(10, 2),
DataType::Utf8,
];
for data_type in &data_types {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
Column::new_with_schema("c1", &input_schema).unwrap(),
)];
let result_agg_phy_exprs = create_aggregate_expr(
&array_agg_udaf(),
&input_phy_exprs[0..1],
&[],
&[],
&[],
&input_schema,
"c1",
false,
false,
)?;
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new_list("c1", Field::new("item", data_type.clone(), true), true,),
result_agg_phy_exprs.field().unwrap()
);

let result_distinct = create_aggregate_expr(
&array_agg_udaf(),
&input_phy_exprs[0..1],
&[],
&[],
&[],
&input_schema,
"c1",
false,
true,
)?;
assert_eq!("c1", result_distinct.name());
assert_eq!(
Field::new_list("c1", Field::new("item", data_type.clone(), true), true,),
result_distinct.field().unwrap()
);
}
Ok(())
}
}
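
For reference, the accumulator above buffers each incoming batch and concatenates the buffers once at evaluation time. A standalone sketch of that behaviour, not part of the commit; the values are illustrative:

use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::Accumulator;
use datafusion_functions_aggregate::array_agg::ArrayAggAccumulator;

fn array_agg_accumulator_demo() -> Result<()> {
    let mut acc = ArrayAggAccumulator::try_new(&DataType::Int64)?;
    // Each update_batch call buffers one batch of raw input values.
    acc.update_batch(&[Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef])?;
    acc.update_batch(&[Arc::new(Int64Array::from(vec![3])) as ArrayRef])?;
    // evaluate() concatenates the buffered batches into a single list scalar.
    let list = acc.evaluate()?; // ScalarValue::List containing [1, 2, 3]
    println!("{list:?}");
    Ok(())
}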