From 47ab2535e85766a3c5ba025b0b569f1de9ea5bd9 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 10:32:47 +0200 Subject: [PATCH] Address PR comments --- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/functions-aggregate/Cargo.toml | 1 - .../functions-aggregate/src/array_agg.rs | 75 +------------------ .../tests/cases/roundtrip_logical_plan.rs | 4 +- 4 files changed, 7 insertions(+), 75 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 3acf5f814984..84b791a3de05 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - true + false ),]) ); diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 3331701844b4..26630a0352d5 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,7 +40,6 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 27e3c11049f2..a0cedf5817ff 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -1,4 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under on +// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -17,9 +17,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::DataType; -use arrow_array::Array; use arrow_schema::Field; use datafusion_common::cast::as_list_array; @@ -40,7 +39,7 @@ make_udaf_expr_and_func!( ArrayAgg, array_agg, expression, - "Computes the nth value", + "input values, including nulls, concatenated into an array", array_agg_udaf ); @@ -92,7 +91,7 @@ impl AggregateUDFImpl for ArrayAgg { Ok(vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), - true, + args.input_nullable, )]) } @@ -203,69 +202,3 @@ impl Accumulator for ArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; - use datafusion_physical_expr_common::expressions::column::Column; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - - #[test] - fn test_array_agg_expr() -> Result<()> { - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_aggregate_expr( - &array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - false, - )?; - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - - let result_distinct = create_aggregate_expr( - &array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - true, - )?; - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - } - Ok(()) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cf842453822d..95e75f825cfd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ - array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or + array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -675,7 +675,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), - array_agg(lit(1)) + array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped