Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert StringAgg to UDAF #10945

Merged
merged 11 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ pub enum AggregateFunction {
BoolAnd,
/// Bool Or
BoolOr,
/// String aggregation
StringAgg,
}

impl AggregateFunction {
Expand All @@ -68,7 +66,6 @@ impl AggregateFunction {
Grouping => "GROUPING",
BoolAnd => "BOOL_AND",
BoolOr => "BOOL_OR",
StringAgg => "STRING_AGG",
}
}
}
Expand All @@ -92,7 +89,6 @@ impl FromStr for AggregateFunction {
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
// other
Expand Down Expand Up @@ -146,7 +142,6 @@ impl AggregateFunction {
)))),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
}
}
}
Expand Down Expand Up @@ -195,9 +190,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::StringAgg => {
Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
}
}
}
}
Expand Down
26 changes: 0 additions & 26 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,23 +145,6 @@ pub fn coerce_types(
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
if !is_string_agg_supported_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}",
agg_fun,
input_types[0]
);
}
if !is_string_agg_supported_arg_type(&input_types[1]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}",
agg_fun,
input_types[1]
);
}
Ok(vec![LargeUtf8, input_types[1].clone()])
}
}
}

Expand Down Expand Up @@ -391,15 +374,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
arg_type.is_integer()
}

/// Return `true` if `arg_type` is of a [`DataType`] that the
/// [`AggregateFunction::StringAgg`] aggregation can operate on.
///
/// The accepted types are exactly [`DataType::Utf8`], [`DataType::LargeUtf8`]
/// and [`DataType::Null`]; every other type yields `false`.
pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub mod approx_median;
pub mod approx_percentile_cont;
pub mod approx_percentile_cont_with_weight;
pub mod bit_and_or_xor;
pub mod string_agg;

use crate::approx_percentile_cont::approx_percentile_cont_udaf;
use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf;
Expand Down Expand Up @@ -138,6 +139,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
approx_distinct::approx_distinct_udaf(),
approx_percentile_cont_udaf(),
approx_percentile_cont_with_weight_udaf(),
string_agg::string_agg_udaf(),
bit_and_or_xor::bit_and_udaf(),
bit_and_or_xor::bit_or_udaf(),
bit_and_or_xor::bit_xor_udaf(),
Expand Down
153 changes: 153 additions & 0 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`StringAgg`] aggregate function and its [`StringAggAccumulator`], which together implement the `string_agg` function

use arrow::array::ArrayRef;
use arrow_schema::DataType;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::Result;
use datafusion_common::{not_impl_err, ScalarValue};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility,
};
use std::any::Any;

// Wires `StringAgg` into the crate through the `make_udaf_expr_and_func!` macro.
// The arguments are, in order: the implementing type, the name of the
// expression-building function, that function's parameters (`expr` and
// `delimiter`), its documentation string, and the name of the UDAF accessor.
// NOTE(review): the macro definition is not part of this file; presumably it
// generates `string_agg(expr, delimiter)` and `string_agg_udaf()` — confirm
// against the macro before relying on the exact generated signatures.
make_udaf_expr_and_func!(
StringAgg,
string_agg,
expr delimiter,
"Concatenates the values of string expressions and places separator values between them",
string_agg_udaf
);

/// STRING_AGG aggregate expression
///
/// Aggregate UDF named `string_agg`. It takes a string expression plus a
/// delimiter and produces a single `LargeUtf8` value.
#[derive(Debug)]
pub struct StringAgg {
/// Argument type combinations accepted by the function; built once in
/// `StringAgg::new` and handed out by reference afterwards.
signature: Signature,
}

impl StringAgg {
    /// Create a new StringAgg aggregate function
    ///
    /// The resulting signature accepts exactly two arguments: a `LargeUtf8`
    /// value followed by a delimiter of type `Utf8`, `LargeUtf8` or `Null`.
    /// The function is declared `Immutable`.
    pub fn new() -> Self {
        // One `Exact` signature per supported delimiter type, in this order.
        let delimiter_types = vec![DataType::Utf8, DataType::LargeUtf8, DataType::Null];
        let accepted = delimiter_types
            .into_iter()
            .map(|delimiter_type| {
                TypeSignature::Exact(vec![DataType::LargeUtf8, delimiter_type])
            })
            .collect::<Vec<_>>();

        let signature = Signature::one_of(accepted, Volatility::Immutable);
        Self { signature }
    }
}

impl Default for StringAgg {
fn default() -> Self {
Self::new()
}
}

impl AggregateUDFImpl for StringAgg {
    /// Exposes the concrete type as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of the function: `string_agg`.
    fn name(&self) -> &str {
        "string_agg"
    }

    /// Accepted argument type combinations.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `LargeUtf8`, whatever the argument types are.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::LargeUtf8)
    }

    /// Builds the accumulator for one aggregation.
    ///
    /// The delimiter (second input expression) must be a literal. A string
    /// literal is used as is; a NULL string literal or an untyped NULL is
    /// treated as the empty delimiter.
    ///
    /// # Errors
    ///
    /// Returns a "not implemented" error when the delimiter is anything other
    /// than one of the literals described above.
    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
        let delimiter_expr = &acc_args.input_exprs[1];

        // Resolve the delimiter text first, then construct the accumulator once.
        let delimiter: &str = match delimiter_expr {
            Expr::Literal(ScalarValue::Utf8(literal) | ScalarValue::LargeUtf8(literal)) => {
                literal.as_deref().unwrap_or("")
            }
            Expr::Literal(ScalarValue::Null) => "",
            _ => {
                return not_impl_err!(
                    "StringAgg not supported for delimiter {}",
                    delimiter_expr
                )
            }
        };

        Ok(Box::new(StringAggAccumulator::new(delimiter)))
    }
}

/// Accumulator backing the `string_agg` aggregate: concatenates the non-null
/// input strings, separated by `delimiter`.
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
/// Concatenation built so far; `None` until the first non-null input is
/// seen, in which case the aggregate evaluates to a NULL `LargeUtf8`.
values: Option<String>,
/// Separator placed between consecutive values.
delimiter: String,
}

impl StringAggAccumulator {
    /// Creates an accumulator that has seen no input yet and will separate
    /// values with an owned copy of `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        let delimiter = String::from(delimiter);
        Self {
            values: None,
            delimiter,
        }
    }
}

impl Accumulator for StringAggAccumulator {
    /// Appends every non-null string of `values[0]` to the running
    /// concatenation, inserting the delimiter between consecutive values.
    ///
    /// `values[0]` must be a `LargeUtf8` array; any other array type makes the
    /// downcast fail and the error is returned. Null entries are skipped.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let strings = as_generic_string_array::<i64>(&values[0])?;
        for s in strings.iter().flatten() {
            // Whether a delimiter is needed depends on whether *any* value has
            // been accumulated, not on whether the accumulated text is
            // non-empty: an accumulated empty string is still a value, and
            // keying on emptiness would make the result depend on where batch
            // boundaries fall.
            if let Some(acc) = self.values.as_mut() {
                acc.push_str(&self.delimiter);
                acc.push_str(s);
            } else {
                self.values = Some(s.to_string());
            }
        }
        Ok(())
    }

    /// Merges partial states. A state is a single `LargeUtf8` column holding
    /// already-concatenated strings, so merging is the same operation as
    /// updating with raw input.
    fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        self.update_batch(values)
    }

    /// Intermediate state: one scalar, identical to the final value.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Ok(vec![self.evaluate()?])
    }

    /// Returns the concatenation as `LargeUtf8`, or a NULL `LargeUtf8` when no
    /// non-null input has been seen.
    fn evaluate(&mut self) -> Result<ScalarValue> {
        Ok(ScalarValue::LargeUtf8(self.values.clone()))
    }

    /// Size of the accumulator itself plus the capacity of the two strings it
    /// owns.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
            + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
            + self.delimiter.capacity()
    }
}
16 changes: 0 additions & 16 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,22 +155,6 @@ pub fn create_aggregate_expr(
ordering_req.to_vec(),
))
}
(AggregateFunction::StringAgg, false) => {
if !ordering_req.is_empty() {
return not_impl_err!(
"STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available"
);
}
Arc::new(expressions::StringAgg::new(
input_phy_exprs[0].clone(),
input_phy_exprs[1].clone(),
name,
data_type,
))
}
(AggregateFunction::StringAgg, true) => {
return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available");
}
})
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pub(crate) mod correlation;
pub(crate) mod covariance;
pub(crate) mod grouping;
pub(crate) mod nth_value;
pub(crate) mod string_agg;
#[macro_use]
pub(crate) mod min_max;
pub(crate) mod groups_accumulator;
Expand Down
Loading