Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move ArrayDims, ArrayNdims and Cardinality to datafusion-function-crate #9425

Merged
merged 3 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,6 @@ pub enum BuiltinScalarFunction {
ArrayPopFront,
/// array_pop_back
ArrayPopBack,
/// array_dims
ArrayDims,
/// array_distinct
ArrayDistinct,
/// array_element
Expand All @@ -140,8 +138,6 @@ pub enum BuiltinScalarFunction {
ArrayEmpty,
/// array_length
ArrayLength,
/// array_ndims
ArrayNdims,
/// array_position
ArrayPosition,
/// array_positions
Expand Down Expand Up @@ -172,8 +168,6 @@ pub enum BuiltinScalarFunction {
ArrayUnion,
/// array_except
ArrayExcept,
/// cardinality
Cardinality,
/// array_resize
ArrayResize,
/// construct an array from columns
Expand Down Expand Up @@ -385,12 +379,10 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable,
BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
BuiltinScalarFunction::ArrayNdims => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopFront => Volatility::Immutable,
BuiltinScalarFunction::ArrayPopBack => Volatility::Immutable,
BuiltinScalarFunction::ArrayPosition => Volatility::Immutable,
Expand All @@ -409,7 +401,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::ArrayUnion => Volatility::Immutable,
BuiltinScalarFunction::ArrayResize => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::Ascii => Volatility::Immutable,
BuiltinScalarFunction::BitLength => Volatility::Immutable,
Expand Down Expand Up @@ -561,9 +552,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::ArrayHasAny
| BuiltinScalarFunction::ArrayHas
| BuiltinScalarFunction::ArrayEmpty => Ok(Boolean),
BuiltinScalarFunction::ArrayDims => {
Ok(List(Arc::new(Field::new("item", UInt64, true))))
}
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field)
Expand All @@ -574,7 +562,6 @@ impl BuiltinScalarFunction {
),
},
BuiltinScalarFunction::ArrayLength => Ok(UInt64),
BuiltinScalarFunction::ArrayNdims => Ok(UInt64),
BuiltinScalarFunction::ArrayPopFront => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayPopBack => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayPosition => Ok(UInt64),
Expand Down Expand Up @@ -622,7 +609,6 @@ impl BuiltinScalarFunction {
(dt, _) => Ok(dt),
}
}
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
_ => {
Expand Down Expand Up @@ -884,7 +870,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayDims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayEmpty => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayElement => {
Signature::array_and_index(self.volatility())
Expand All @@ -900,7 +885,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayLength => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayNdims => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::array_and_element_and_optional_index(self.volatility())
Expand Down Expand Up @@ -931,7 +915,6 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::array(self.volatility()),
BuiltinScalarFunction::ArrayResize => {
Signature::variadic_any(self.volatility())
}
Expand Down Expand Up @@ -1481,7 +1464,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayConcat => {
&["array_concat", "array_cat", "list_concat", "list_cat"]
}
BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"],
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
BuiltinScalarFunction::ArrayEmpty => &["empty"],
BuiltinScalarFunction::ArrayElement => &[
Expand All @@ -1498,7 +1480,6 @@ impl BuiltinScalarFunction {
&["array_has", "list_has", "array_contains", "list_contains"]
}
BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"],
BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"],
BuiltinScalarFunction::ArrayPopFront => {
&["array_pop_front", "list_pop_front"]
}
Expand Down Expand Up @@ -1534,7 +1515,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReverse => &["array_reverse", "list_reverse"],
BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"],
BuiltinScalarFunction::ArrayUnion => &["array_union", "list_union"],
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::ArrayResize => &["array_resize", "list_resize"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::ArrayIntersect => {
Expand Down
22 changes: 0 additions & 22 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,12 +628,6 @@ scalar_expr!(
array,
"flattens an array of arrays into a single array."
);
scalar_expr!(
ArrayDims,
array_dims,
array,
"returns an array of the array's dimensions."
);
scalar_expr!(
ArrayElement,
array_element,
Expand All @@ -652,12 +646,6 @@ scalar_expr!(
array dimension,
"returns the length of the array dimension."
);
scalar_expr!(
ArrayNdims,
array_ndims,
array,
"returns the number of dimensions of the array."
);
scalar_expr!(
ArrayDistinct,
array_distinct,
Expand Down Expand Up @@ -738,13 +726,6 @@ scalar_expr!(
);
scalar_expr!(ArrayUnion, array_union, array1 array2, "returns an array of the elements in the union of array1 and array2 without duplicates.");

scalar_expr!(
Cardinality,
cardinality,
array,
"returns the total number of elements in the array."
);

scalar_expr!(
ArrayResize,
array_resize,
Expand Down Expand Up @@ -1389,9 +1370,7 @@ mod test {
test_scalar_expr!(ArraySort, array_sort, array, desc, null_first);
test_scalar_expr!(ArrayPopFront, array_pop_front, array);
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
test_unary_scalar_expr!(ArrayDims, array_dims);
test_scalar_expr!(ArrayLength, array_length, array, dimension);
test_unary_scalar_expr!(ArrayNdims, array_ndims);
test_scalar_expr!(ArrayPosition, array_position, array, element, index);
test_scalar_expr!(ArrayPositions, array_positions, array, element);
test_scalar_expr!(ArrayPrepend, array_prepend, array, element);
Expand All @@ -1402,7 +1381,6 @@ mod test {
test_scalar_expr!(ArrayReplace, array_replace, array, from, to);
test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max);
test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to);
test_unary_scalar_expr!(Cardinality, cardinality);
test_nary_scalar_expr!(MakeArray, array, input);

test_unary_scalar_expr!(ArrowTypeof, arrow_typeof);
Expand Down
139 changes: 129 additions & 10 deletions datafusion/functions-array/src/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@

//! implementation kernels for array functions

use arrow::array::ListArray;
use arrow::array::{
Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, OffsetSizeTrait,
StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::DataType;
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
use arrow::datatypes::{DataType, UInt64Type};
use datafusion_common::cast::{
as_int64_array, as_large_list_array, as_list_array, as_string_array,
};
use datafusion_common::{exec_err, DataFusionError};
use datafusion_common::{exec_err, DataFusionError, Result};
use std::any::type_name;
use std::sync::Arc;
macro_rules! downcast_arg {
Expand Down Expand Up @@ -102,7 +105,7 @@ macro_rules! call_array_function {
}

/// Array_to_string SQL function
pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
pub(super) fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_to_string expects two or three arguments");
}
Expand Down Expand Up @@ -254,9 +257,6 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> datafusion_common::Result<Ar
Ok(Arc::new(string_arr))
}

use arrow::array::ListArray;
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::Field;
/// Generates an array of integers from start to stop with a given step.
///
/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values.
Expand All @@ -271,10 +271,7 @@ use arrow::datatypes::Field;
/// gen_range(3) => [0, 1, 2]
/// gen_range(1, 4) => [1, 2, 3]
/// gen_range(1, 7, 2) => [1, 3, 5]
pub fn gen_range(
args: &[ArrayRef],
include_upper: i64,
) -> datafusion_common::Result<ArrayRef> {
pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result<ArrayRef> {
let (start_array, stop_array, step_array) = match args.len() {
1 => (None, as_int64_array(&args[0])?, None),
2 => (
Expand Down Expand Up @@ -319,3 +316,125 @@ pub fn gen_range(
)?);
Ok(arr)
}

/// Returns the length of each array dimension
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is just following the existing pattern of array_functions, but I wonder if it would be better to organize the code by function.

For example, we could put the UDF and implementations in datafusion/functions-array/src/dims.rs 🤔

We could definitely do this as a follow on PR

Any thoughts @jayzhan211 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we split two files udf and kernel inside dims?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the difficult part is naming the category name. Like array_append, array_prepend and array_concat. Should we name it mutable? How about pop_front and pop_back.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree how to organize the functions is not clear.

We could always just make one file for each function, I suppose. array_append.rs, array_prepend.rs, etc. Though that feels like a bit of an overkill

fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
let mut value = match arr {
Some(arr) => arr,
None => return Ok(None),
};
if value.is_empty() {
return Ok(None);
}
let mut res = vec![Some(value.len() as u64)];

loop {
match value.data_type() {
DataType::List(..) => {
value = downcast_arg!(value, ListArray).value(0);
res.push(Some(value.len() as u64));
}
_ => return Ok(Some(res)),
}
}
}

fn generic_list_cardinality<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
) -> Result<ArrayRef> {
let result = array
.iter()
.map(|arr| match compute_array_dims(arr)? {
Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::<u64>())),
None => Ok(None),
})
.collect::<Result<UInt64Array>>()?;
Ok(Arc::new(result) as ArrayRef)
}

/// Cardinality SQL function
pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("cardinality expects one argument");
}

match &args[0].data_type() {
DataType::List(_) => {
let list_array = as_list_array(&args[0])?;
generic_list_cardinality::<i32>(list_array)
}
DataType::LargeList(_) => {
let list_array = as_large_list_array(&args[0])?;
generic_list_cardinality::<i64>(list_array)
}
other => {
exec_err!("cardinality does not support type '{:?}'", other)
}
}
}

/// Array_dims SQL function
pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_dims needs one argument");
}

let data = match args[0].data_type() {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
array
.iter()
.map(compute_array_dims)
.collect::<Result<Vec<_>>>()?
}
DataType::LargeList(_) => {
let array = as_large_list_array(&args[0])?;
array
.iter()
.map(compute_array_dims)
.collect::<Result<Vec<_>>>()?
}
array_type => {
return exec_err!("array_dims does not support type '{array_type:?}'");
}
};

let result = ListArray::from_iter_primitive::<UInt64Type, _, _>(data);

Ok(Arc::new(result) as ArrayRef)
}

/// Array_ndims SQL function
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 1 {
return exec_err!("array_ndims needs one argument");
}

fn general_list_ndims<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
) -> Result<ArrayRef> {
let mut data = Vec::new();
let ndims = datafusion_common::utils::list_ndims(array.data_type());

for arr in array.iter() {
if arr.is_some() {
data.push(Some(ndims))
} else {
data.push(None)
}
}

Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
}
match args[0].data_type() {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
general_list_ndims::<i32>(array)
}
DataType::LargeList(_) => {
let array = as_large_list_array(&args[0])?;
general_list_ndims::<i64>(array)
}
array_type => exec_err!("array_ndims does not support type {array_type:?}"),
}
}
6 changes: 6 additions & 0 deletions datafusion/functions-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ use std::sync::Arc;

/// Fluent-style API for creating `Expr`s
pub mod expr_fn {
pub use super::udf::array_dims;
pub use super::udf::array_ndims;
pub use super::udf::array_to_string;
pub use super::udf::cardinality;
pub use super::udf::gen_series;
pub use super::udf::range;
}
Expand All @@ -50,6 +53,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
udf::array_to_string_udf(),
udf::range_udf(),
udf::gen_series_udf(),
udf::array_dims_udf(),
udf::cardinality_udf(),
udf::array_ndims_udf(),
];
functions.into_iter().try_for_each(|udf| {
let existing_udf = registry.register_udf(udf)?;
Expand Down
Loading
Loading