Skip to content

Commit

Permalink
move ArrayDims, ArrayNdims and Cardinality to datafusion-function-cra…
Browse files Browse the repository at this point in the history
…te (#9425)

* Update array functions and remove ArrayDims and Cardinality

* move ArrayNdims function

* add roundtrip tests
  • Loading branch information
Weijun-H authored Mar 3, 2024
1 parent 89aea0a commit f229dcc
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 228 deletions.
20 changes: 0 additions & 20 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ pub enum BuiltinScalarFunction {
ArrayPopFront,
/// array_pop_back
ArrayPopBack,
/// array_dims
ArrayDims,
/// array_distinct
ArrayDistinct,
/// array_element
Expand All @@ -140,8 +138,6 @@ pub enum BuiltinScalarFunction {
ArrayEmpty,
/// array_length
ArrayLength,
/// array_ndims
ArrayNdims,
/// array_position
ArrayPosition,
/// array_positions
Expand Down Expand Up @@ -172,8 +168,6 @@ pub enum BuiltinScalarFunction {
ArrayUnion,
/// array_except
ArrayExcept,
/// cardinality
Cardinality,
/// array_resize
ArrayResize,
/// construct an array from columns
Expand Down Expand Up @@ -385,12 +379,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable,
BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
BuiltinScalarFunction::ArrayNdims => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable,
BuiltinScalarFunction::ArrayPosition => Volatility::Immutable,
Expand All @@ -409,7 +401,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::ArrayResize => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
Expand Down Expand Up @@ -561,9 +552,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::ArrayHasAny
| BuiltinScalarFunction::ArrayHas
| BuiltinScalarFunction::ArrayEmpty => Ok(Boolean),
BuiltinScalarFunction::ArrayDims => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field)
Expand All @@ -574,7 +562,6 @@ impl BuiltinScalarFunction {
),
},
BuiltinScalarFunction::ArrayLength => Ok(UInt64),
BuiltinScalarFunction::ArrayNdims => Ok(UInt64),
BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayPosition => Ok(UInt64),
Expand Down Expand Up @@ -622,7 +609,6 @@ impl BuiltinScalarFunction {
(dt, _) => Ok(dt),
}
}
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
_ => {
Expand Down Expand Up @@ -884,7 +870,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_index(self.volatility())
Expand All @@ -900,7 +885,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
Expand Down Expand Up @@ -931,7 +915,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}
Expand Down Expand Up @@ -1481,7 +1464,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
&["array_concat", "array_cat", "list_concat", "list_cat"]
}
BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"],
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
BuiltinScalarFunction::ArrayEmpty => &["empty"],
BuiltinScalarFunction::ArrayElement => &[
Expand All @@ -1498,7 +1480,6 @@ impl BuiltinScalarFunction {
&["array_has", "list_has", "array_contains", "list_contains"]
}
BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"],
BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"],
BuiltinScalarFunction::ArrayPopFront => {
&["array_pop_front", "list_pop_front"]
}
Expand Down Expand Up @@ -1534,7 +1515,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReverse => &["array_reverse", "list_reverse"],
BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"],
BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => {
Expand Down
22 changes: 0 additions & 22 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,12 +628,6 @@ scalar_expr!(
array,
"flattens an array of arrays into a single array."
);
scalar_expr!(
ArrayDims,
array_dims,
array,
"returns an array of the array's dimensions."
);
scalar_expr!(
ArrayElement,
array_element,
Expand All @@ -652,12 +646,6 @@ scalar_expr!(
array dimension,
"returns the length of the array dimension."
);
scalar_expr!(
ArrayNdims,
array_ndims,
array,
"returns the number of dimensions of the array."
);
scalar_expr!(
ArrayDistinct,
array_distinct,
Expand Down Expand Up @@ -738,13 +726,6 @@ scalar_expr!(
);
scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates.");

scalar_expr!(
Cardinality,
cardinality,
array,
"returns the total number of elements in the array."
);

scalar_expr!(
ArrayResize,
array_resize,
Expand Down Expand Up @@ -1389,9 +1370,7 @@ mod test {
test_scalar_expr!(ArraySort, array_sort, array, desc, null_first);
test_scalar_expr!(ArrayPopFront, array_pop_front, array);
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
test_unary_scalar_expr!(ArrayDims, array_dims);
test_scalar_expr!(ArrayLength, array_length, array, dimension);
test_unary_scalar_expr!(ArrayNdims, array_ndims);
test_scalar_expr!(ArrayPosition, array_position, array, element, index);
test_scalar_expr!(ArrayPositions, array_positions, array, element);
test_scalar_expr!(ArrayPrepend, array_prepend, array, element);
Expand All @@ -1402,7 +1381,6 @@ mod test {
test_scalar_expr!(ArrayReplace, array_replace, array, from, to);
test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max);
test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to);
test_unary_scalar_expr!(Cardinality, cardinality);
test_nary_scalar_expr!(MakeArray, array, input);

test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
Expand Down
139 changes: 129 additions & 10 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

//! implementation kernels for array functions
use arrow::array::ListArray;
use arrow::array::{
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, OffsetSizeTrait,
StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::DataType;
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
use arrow::datatypes::{DataType, UInt64Type};
use datafusion_common::cast::{
as_int64_array, as_large_list_array, as_list_array, as_string_array,
};
use datafusion_common::{exec_err, DataFusionError};
use datafusion_common::{exec_err, DataFusionError, Result};
use std::any::type_name;
use std::sync::Arc;
macro_rules! downcast_arg {
Expand Down Expand Up @@ -102,7 +105,7 @@ macro_rules! call_array_function {
}

/// Array_to_string SQL function
pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
pub(super) fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_to_string expects two or three arguments");
}
Expand Down Expand Up @@ -254,9 +257,6 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result<Ar
Ok(Arc::new(string_arr))
}

use arrow::array::ListArray;
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
/// Generates an array of integers from start to stop with a given step.
///
/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values.
Expand All @@ -271,10 +271,7 @@ use arrow::datatypes::Field;
/// gen_range(3) => [0, 1, 2]
/// gen_range(1, 4) => [1, 2, 3]
/// gen_range(1, 7, 2) => [1, 3, 5]
pub fn gen_range(
args: &[ArrayRef],
include_upper: i64,
) -> datafusion_common::Result<ArrayRef> {
pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result<ArrayRef> {
let (start_array, stop_array, step_array) = match args.len() {
1 => (None, as_int64_array(&args[0])?, None),
2 => (
Expand Down Expand Up @@ -319,3 +316,125 @@ pub fn gen_range(
)?);
Ok(arr)
}

/// Returns the length of each array dimension.
///
/// Starting from `arr`, records the length of the current array and, while the
/// current array's data type is `DataType::List`, descends into its first
/// element (index 0) and records that element's length as the next dimension.
/// Only the first element at each level is inspected, and only `List` nesting
/// is followed; any other data type ends the walk.
///
/// Returns `Ok(None)` when `arr` is `None` or is an empty array. Returns an
/// error if an array whose data type is `List` cannot be downcast to
/// `ListArray`.
fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
    let mut value = match arr {
        Some(arr) => arr,
        None => return Ok(None),
    };
    if value.is_empty() {
        return Ok(None);
    }
    let mut res = vec![Some(value.len() as u64)];

    loop {
        match value.data_type() {
            DataType::List(..) => {
                let child = {
                    let list = downcast_arg!(value, ListArray);
                    // A nested list level can be empty (e.g. `[[]]` with deeper
                    // list nesting). There is no element 0 to descend into, and
                    // `ListArray::value(0)` would panic, so stop here with the
                    // dimensions collected so far.
                    if list.is_empty() {
                        return Ok(Some(res));
                    }
                    list.value(0)
                };
                res.push(Some(child.len() as u64));
                value = child;
            }
            _ => return Ok(Some(res)),
        }
    }
}

/// Computes one `UInt64` value per row of `array`: the product of the
/// dimension lengths reported by `compute_array_dims` for that row, or null
/// when `compute_array_dims` yields `None` for it.
fn generic_list_cardinality<O: OffsetSizeTrait>(
    array: &GenericListArray<O>,
) -> Result<ArrayRef> {
    let counts = array
        .iter()
        .map(|row| -> Result<Option<u64>> {
            let dims = compute_array_dims(row)?;
            // Multiply the per-dimension lengths together; every entry
            // produced by `compute_array_dims` is `Some`.
            Ok(dims.map(|lengths| {
                lengths.into_iter().map(Option::unwrap).product::<u64>()
            }))
        })
        .collect::<Result<UInt64Array>>()?;

    Ok(Arc::new(counts) as ArrayRef)
}

/// Cardinality SQL function
///
/// Expects exactly one argument whose data type is `List` or `LargeList` and
/// delegates to `generic_list_cardinality` with the matching offset type.
/// Any other argument count or data type yields an execution error.
pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
    let input = match args {
        [only] => only,
        _ => return exec_err!("cardinality expects one argument"),
    };

    match input.data_type() {
        DataType::List(_) => generic_list_cardinality::<i32>(as_list_array(input)?),
        DataType::LargeList(_) => {
            generic_list_cardinality::<i64>(as_large_list_array(input)?)
        }
        other => {
            exec_err!("cardinality does not support type '{:?}'", other)
        }
    }
}

/// Array_dims SQL function
///
/// Expects exactly one `List` or `LargeList` argument. For each row it runs
/// `compute_array_dims` and packs the per-row results into a `ListArray` of
/// `UInt64` values (rows for which no dimensions are produced become null).
/// Any other argument count or data type yields an execution error.
pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
    let input = match args {
        [only] => only,
        _ => return exec_err!("array_dims needs one argument"),
    };

    // Both list flavours produce the same per-row shape; only the downcast differs.
    let dims_per_row: Vec<Option<Vec<Option<u64>>>> = match input.data_type() {
        DataType::List(_) => as_list_array(input)?
            .iter()
            .map(compute_array_dims)
            .collect::<Result<_>>()?,
        DataType::LargeList(_) => as_large_list_array(input)?
            .iter()
            .map(compute_array_dims)
            .collect::<Result<_>>()?,
        array_type => {
            return exec_err!("array_dims does not support type '{array_type:?}'");
        }
    };

    let dims = ListArray::from_iter_primitive::<UInt64Type, _, _>(dims_per_row);
    Ok(Arc::new(dims) as ArrayRef)
}

/// Array_ndims SQL function
///
/// Expects exactly one `List` or `LargeList` argument. The dimension count is
/// obtained once from the array's data type via
/// `datafusion_common::utils::list_ndims`; every non-null row receives that
/// value and every null row stays null. Any other argument count or data type
/// yields an execution error.
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
    /// Builds the `UInt64` result column for a list array of either offset size.
    fn ndims_per_row<O: OffsetSizeTrait>(
        array: &GenericListArray<O>,
    ) -> Result<ArrayRef> {
        let ndims = datafusion_common::utils::list_ndims(array.data_type());

        // Same value for every present row; nulls are carried through.
        let data: Vec<_> = array.iter().map(|row| row.map(|_| ndims)).collect();

        Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
    }

    let input = match args {
        [only] => only,
        _ => return exec_err!("array_ndims needs one argument"),
    };

    match input.data_type() {
        DataType::List(_) => ndims_per_row::<i32>(as_list_array(input)?),
        DataType::LargeList(_) => ndims_per_row::<i64>(as_large_list_array(input)?),
        array_type => exec_err!("array_ndims does not support type {array_type:?}"),
    }
}
6 changes: 6 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ use std::sync::Arc;

/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::udf::array_dims;
pub use super::udf::array_ndims;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
pub use super::udf::gen_series;
pub use super::udf::range;
}
Expand All @@ -50,6 +53,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::array_to_string_udf(),
udf::range_udf(),
udf::gen_series_udf(),
udf::array_dims_udf(),
udf::cardinality_udf(),
udf::array_ndims_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
Loading

0 comments on commit f229dcc

Please sign in to comment.