Skip to content

Commit

Permalink
chore: extract predicate_functions expressions to folders based on sp…
Browse files Browse the repository at this point in the history
…ark grouping (#1218)

* extract predicate_functions expressions to folders based on spark grouping

* code review changes

---------

Co-authored-by: Andy Grove <[email protected]>
  • Loading branch information
rluvaton and andygrove authored Jan 8, 2025
1 parent c19202c commit fbcf025
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 4 deletions.
4 changes: 2 additions & 2 deletions native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub use checkoverflow::CheckOverflow;

mod kernels;
mod list;
mod regexp;
pub mod scalar_funcs;
mod schema_adapter;
mod static_invoke;
Expand All @@ -50,6 +49,8 @@ mod unbound;
pub use unbound::UnboundColumn;
pub mod utils;
pub use normalize_nan::NormalizeNaNAndZero;
mod predicate_funcs;
pub use predicate_funcs::{spark_isnan, RLike};

mod agg_funcs;
mod comet_scalar_funcs;
Expand All @@ -66,7 +67,6 @@ pub use datetime_funcs::*;
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
pub use list::{ArrayInsert, GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use string_funcs::*;
pub use struct_funcs::*;
pub use to_json::ToJson;
Expand Down
70 changes: 70 additions & 0 deletions native/spark-expr/src/predicate_funcs/is_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::{Float32Array, Float64Array};
use arrow_array::{Array, BooleanArray};
use arrow_schema::DataType;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use std::sync::Arc;

/// Spark-compatible `isnan` expression
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
match is_nan.nulls() {
Some(nulls) => {
let is_not_null = nulls.inner();
ColumnarValue::Array(Arc::new(BooleanArray::new(
is_nan.values() & is_not_null,
None,
)))
}
None => ColumnarValue::Array(Arc::new(is_nan)),
}
}
let value = &args[0];
match value {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Float64 => {
let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
Ok(set_nulls_to_false(is_nan))
}
DataType::Float32 => {
let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
Ok(set_nulls_to_false(is_nan))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function isnan",
other,
))),
},
ColumnarValue::Scalar(a) => match a {
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
a.map(|x| x.is_nan()).unwrap_or(false),
)))),
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
a.map(|x| x.is_nan()).unwrap_or(false),
)))),
_ => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function isnan",
value.data_type(),
))),
},
}
}
22 changes: 22 additions & 0 deletions native/spark-expr/src/predicate_funcs/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

mod is_nan;
mod rlike;

pub use is_nan::spark_isnan;
pub use rlike::RLike;
File renamed without changes.
81 changes: 79 additions & 2 deletions native/spark-expr/src/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ use arrow::{
ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int64Builder, Int8Array,
},
compute::kernels::numeric::{add, sub},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION};
use arrow_array::builder::IntervalDayTimeBuilder;
use arrow_array::types::{Int16Type, Int32Type, Int8Type, IntervalDayTime};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
use datafusion::physical_expr_common::datum;
use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
use datafusion_common::{
exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
Expand Down Expand Up @@ -447,6 +451,79 @@ pub fn spark_decimal_div(
Ok(ColumnarValue::Array(Arc::new(result)))
}

macro_rules! scalar_date_arithmetic {
($start:expr, $days:expr, $op:expr) => {{
let interval = IntervalDayTime::new(*$days as i32, 0);
let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
datum::apply($start, &interval_cv, $op)
}};
}
macro_rules! array_date_arithmetic {
($days:expr, $interval_builder:expr, $intType:ty) => {{
for day in $days.as_primitive::<$intType>().into_iter() {
if let Some(non_null_day) = day {
$interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
} else {
$interval_builder.append_null();
}
}
}};
}

/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
/// second argument and use DataFusion's interface to apply Arrow's operators.
fn spark_date_arithmetic(
args: &[ColumnarValue],
op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
) -> Result<ColumnarValue, DataFusionError> {
let start = &args[0];
match &args[1] {
ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
scalar_date_arithmetic!(start, days, op)
}
ColumnarValue::Array(days) => {
let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
match days.data_type() {
DataType::Int8 => {
array_date_arithmetic!(days, interval_builder, Int8Type)
}
DataType::Int16 => {
array_date_arithmetic!(days, interval_builder, Int16Type)
}
DataType::Int32 => {
array_date_arithmetic!(days, interval_builder, Int32Type)
}
_ => {
return Err(DataFusionError::Internal(format!(
"Unsupported data types {:?} for date arithmetic.",
args,
)))
}
}
let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
datum::apply(start, &interval_cv, op)
}
_ => Err(DataFusionError::Internal(format!(
"Unsupported data types {:?} for date arithmetic.",
args,
))),
}
}
pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
spark_date_arithmetic(args, add)
}

pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
spark_date_arithmetic(args, sub)
}

/// Spark-compatible `isnan` expression
pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
Expand Down

0 comments on commit fbcf025

Please sign in to comment.