From fbcf0251082a43b5ee25b6c5933a9262cce44071 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Wed, 8 Jan 2025 09:10:34 +0200
Subject: [PATCH] chore: extract predicate_functions expressions to folders
 based on spark grouping (#1218)

* extract predicate_functions expressions to folders based on spark grouping

* code review changes

---------

Co-authored-by: Andy Grove
---
 native/spark-expr/src/lib.rs                  |  4 +-
 .../spark-expr/src/predicate_funcs/is_nan.rs  | 70 ++++++++++++++++
 native/spark-expr/src/predicate_funcs/mod.rs  | 22 +++++
 .../{regexp.rs => predicate_funcs/rlike.rs}   |  0
 native/spark-expr/src/scalar_funcs.rs         | 81 ++++++++++++++++++-
 5 files changed, 173 insertions(+), 4 deletions(-)
 create mode 100644 native/spark-expr/src/predicate_funcs/is_nan.rs
 create mode 100644 native/spark-expr/src/predicate_funcs/mod.rs
 rename native/spark-expr/src/{regexp.rs => predicate_funcs/rlike.rs} (100%)

diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c7c54a4e9..c614e1f0a 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -30,7 +30,6 @@ pub use checkoverflow::CheckOverflow;
 
 mod kernels;
 mod list;
-mod regexp;
 pub mod scalar_funcs;
 mod schema_adapter;
 mod static_invoke;
@@ -50,6 +49,8 @@ mod unbound;
 pub use unbound::UnboundColumn;
 pub mod utils;
 pub use normalize_nan::NormalizeNaNAndZero;
+mod predicate_funcs;
+pub use predicate_funcs::{spark_isnan, RLike};
 
 mod agg_funcs;
 mod comet_scalar_funcs;
@@ -66,7 +67,6 @@ pub use datetime_funcs::*;
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
 pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
-pub use regexp::RLike;
 pub use string_funcs::*;
 pub use struct_funcs::*;
 pub use to_json::ToJson;
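Note (example, not part of the diff): because `lib.rs` re-exports the moved items at the crate root, existing downstream imports keep compiling unchanged; a minimal sketch, assuming the crate name `datafusion_comet_spark_expr` from the repository layout:

    // Callers are unaffected by the module move: both names stay at the crate root.
    use datafusion_comet_spark_expr::{spark_isnan, RLike};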
diff --git a/native/spark-expr/src/predicate_funcs/is_nan.rs b/native/spark-expr/src/predicate_funcs/is_nan.rs
new file mode 100644
index 000000000..bf4d7e0f2
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/is_nan.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Float32Array, Float64Array};
+use arrow_array::{Array, BooleanArray};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `isnan` expression
+pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
+        match is_nan.nulls() {
+            Some(nulls) => {
+                let is_not_null = nulls.inner();
+                ColumnarValue::Array(Arc::new(BooleanArray::new(
+                    is_nan.values() & is_not_null,
+                    None,
+                )))
+            }
+            None => ColumnarValue::Array(Arc::new(is_nan)),
+        }
+    }
+    let value = &args[0];
+    match value {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Float64 => {
+                let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            DataType::Float32 => {
+                let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                other,
+            ))),
+        },
+        ColumnarValue::Scalar(a) => match a {
+            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            _ => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                value.data_type(),
+            ))),
+        },
+    }
+}
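Note (example, not part of the diff): a minimal sketch of calling the relocated `spark_isnan`, assuming the crate-root re-export above. It exercises the Spark-compatible null handling, where a NULL input yields `false` rather than NULL:

    use std::sync::Arc;

    use arrow_array::Float64Array;
    use datafusion::physical_plan::ColumnarValue;
    use datafusion_comet_spark_expr::spark_isnan;

    fn main() -> Result<(), datafusion_common::DataFusionError> {
        // NaN -> true, 1.0 -> false, NULL -> false (via set_nulls_to_false).
        let input = Float64Array::from(vec![Some(f64::NAN), Some(1.0), None]);
        let result = spark_isnan(&[ColumnarValue::Array(Arc::new(input))])?;
        if let ColumnarValue::Array(array) = result {
            println!("{array:?}"); // expected values: [true, false, false]
        }
        Ok(())
    }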
diff --git a/native/spark-expr/src/predicate_funcs/mod.rs b/native/spark-expr/src/predicate_funcs/mod.rs
new file mode 100644
index 000000000..5f1f570c0
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod is_nan;
+mod rlike;
+
+pub use is_nan::spark_isnan;
+pub use rlike::RLike;
diff --git a/native/spark-expr/src/regexp.rs b/native/spark-expr/src/predicate_funcs/rlike.rs
similarity index 100%
rename from native/spark-expr/src/regexp.rs
rename to native/spark-expr/src/predicate_funcs/rlike.rs
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index e11d1c5db..9421d54fd 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -20,10 +20,14 @@ use arrow::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
         Int64Array, Int64Builder, Int8Array,
     },
+    compute::kernels::numeric::{add, sub},
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
-use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
+use arrow_array::builder::IntervalDayTimeBuilder;
+use arrow_array::types::{Int16Type, Int32Type, Int8Type, IntervalDayTime};
+use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
+use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
+use datafusion::physical_expr_common::datum;
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
 use datafusion_common::{
     exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
 };
@@ -447,6 +451,79 @@ pub fn spark_decimal_div(
     Ok(ColumnarValue::Array(Arc::new(result)))
 }
 
+macro_rules! scalar_date_arithmetic {
+    ($start:expr, $days:expr, $op:expr) => {{
+        let interval = IntervalDayTime::new(*$days as i32, 0);
+        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
+        datum::apply($start, &interval_cv, $op)
+    }};
+}
+macro_rules! array_date_arithmetic {
+    ($days:expr, $interval_builder:expr, $intType:ty) => {{
+        for day in $days.as_primitive::<$intType>().into_iter() {
+            if let Some(non_null_day) = day {
+                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
+            } else {
+                $interval_builder.append_null();
+            }
+        }
+    }};
+}
+
+/// Spark-compatible `date_add` and `date_sub` expressions, which assume a number of days for the
+/// second argument. Days cannot be added directly to a Date32, so we generate an IntervalDayTime
+/// from the second argument and use DataFusion's interface to apply Arrow's operators.
+fn spark_date_arithmetic(
+    args: &[ColumnarValue],
+    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
+) -> Result<ColumnarValue, DataFusionError> {
+    let start = &args[0];
+    match &args[1] {
+        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
+            scalar_date_arithmetic!(start, days, op)
+        }
+        ColumnarValue::Array(days) => {
+            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
+            match days.data_type() {
+                DataType::Int8 => {
+                    array_date_arithmetic!(days, interval_builder, Int8Type)
+                }
+                DataType::Int16 => {
+                    array_date_arithmetic!(days, interval_builder, Int16Type)
+                }
+                DataType::Int32 => {
+                    array_date_arithmetic!(days, interval_builder, Int32Type)
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "Unsupported data types {:?} for date arithmetic.",
+                        args,
+                    )))
+                }
+            }
+            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
+            datum::apply(start, &interval_cv, op)
+        }
+        _ => Err(DataFusionError::Internal(format!(
+            "Unsupported data types {:?} for date arithmetic.",
+            args,
+        ))),
+    }
+}
+pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_date_arithmetic(args, add)
+}
+
+pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    spark_date_arithmetic(args, sub)
+}
+
 /// Spark-compatible `isnan` expression
 pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
     fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
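Note (example, not part of the diff): a minimal sketch of the new date arithmetic entry points, assuming they are reachable as `datafusion_comet_spark_expr::scalar_funcs::spark_date_add` through the existing `pub mod scalar_funcs`. A scalar day count takes the `scalar_date_arithmetic!` path; an array of day counts is converted element-wise to `IntervalDayTime` first:

    use std::sync::Arc;

    use arrow_array::Date32Array;
    use datafusion::physical_plan::ColumnarValue;
    use datafusion_common::ScalarValue;
    use datafusion_comet_spark_expr::scalar_funcs::spark_date_add;

    fn main() -> Result<(), datafusion_common::DataFusionError> {
        // Date32 stores days since the Unix epoch; 19723 = 2024-01-01.
        let dates = ColumnarValue::Array(Arc::new(Date32Array::from(vec![Some(19723), None])));
        // A scalar Int32 day count is wrapped in an IntervalDayTime and applied via datum::apply.
        let days = ColumnarValue::Scalar(ScalarValue::Int32(Some(7)));
        let added = spark_date_add(&[dates, days])?;
        if let ColumnarValue::Array(array) = added {
            println!("{array:?}"); // expected: [2024-01-08, NULL]
        }
        Ok(())
    }

`spark_date_sub` is the same helper wired to Arrow's `sub` kernel, so `date_sub(d, n)` behaves like `date_add(d, -n)`.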