Skip to content

Commit

Permalink
move Atan2, Atan, Acosh, Asinh, Atanh to `datafusion-functi…
Browse files Browse the repository at this point in the history
…on` (#9872)

* Refactor math functions in datafusion code

* fix ci

* fix: avoid regression

* refactor: move atan2 function

* chore: move atan2 test
  • Loading branch information
Weijun-H authored Mar 31, 2024
1 parent 2cb6f73 commit 66c8ba2
Show file tree
Hide file tree
Showing 13 changed files with 201 additions and 184 deletions.
48 changes: 4 additions & 44 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,6 @@ use strum_macros::EnumIter;
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
pub enum BuiltinScalarFunction {
// math functions
/// atan
Atan,
/// atan2
Atan2,
/// acosh
Acosh,
/// asinh
Asinh,
/// atanh
Atanh,
/// cbrt
Cbrt,
/// ceil
Expand Down Expand Up @@ -159,11 +149,6 @@ impl BuiltinScalarFunction {
pub fn volatility(&self) -> Volatility {
match self {
// Immutable scalar builtins
BuiltinScalarFunction::Atan => Volatility::Immutable,
BuiltinScalarFunction::Atan2 => Volatility::Immutable,
BuiltinScalarFunction::Acosh => Volatility::Immutable,
BuiltinScalarFunction::Asinh => Volatility::Immutable,
BuiltinScalarFunction::Atanh => Volatility::Immutable,
BuiltinScalarFunction::Ceil => Volatility::Immutable,
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
BuiltinScalarFunction::Cos => Volatility::Immutable,
Expand Down Expand Up @@ -238,11 +223,6 @@ impl BuiltinScalarFunction {
_ => Ok(Float64),
},

BuiltinScalarFunction::Atan2 => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
},

BuiltinScalarFunction::Log => match &input_expr_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
Expand All @@ -255,11 +235,7 @@ impl BuiltinScalarFunction {

BuiltinScalarFunction::Iszero => Ok(Boolean),

BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
| BuiltinScalarFunction::Atanh
| BuiltinScalarFunction::Ceil
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Cos
| BuiltinScalarFunction::Cosh
| BuiltinScalarFunction::Degrees
Expand Down Expand Up @@ -332,10 +308,7 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
BuiltinScalarFunction::Atan2 => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
self.volatility(),
),

BuiltinScalarFunction::Log => Signature::one_of(
vec![
Exact(vec![Float32]),
Expand All @@ -355,11 +328,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => {
Signature::uniform(2, vec![Int64], self.volatility())
}
BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
| BuiltinScalarFunction::Atanh
| BuiltinScalarFunction::Cbrt
BuiltinScalarFunction::Cbrt
| BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Cos
| BuiltinScalarFunction::Cosh
Expand Down Expand Up @@ -392,11 +361,7 @@ impl BuiltinScalarFunction {
pub fn monotonicity(&self) -> Option<FuncMonotonicity> {
if matches!(
&self,
BuiltinScalarFunction::Atan
| BuiltinScalarFunction::Acosh
| BuiltinScalarFunction::Asinh
| BuiltinScalarFunction::Atanh
| BuiltinScalarFunction::Ceil
BuiltinScalarFunction::Ceil
| BuiltinScalarFunction::Degrees
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Factorial
Expand All @@ -421,11 +386,6 @@ impl BuiltinScalarFunction {
/// Returns all names that can be used to call this function
pub fn aliases(&self) -> &'static [&'static str] {
match self {
BuiltinScalarFunction::Acosh => &["acosh"],
BuiltinScalarFunction::Asinh => &["asinh"],
BuiltinScalarFunction::Atan => &["atan"],
BuiltinScalarFunction::Atanh => &["atanh"],
BuiltinScalarFunction::Atan2 => &["atan2"],
BuiltinScalarFunction::Cbrt => &["cbrt"],
BuiltinScalarFunction::Ceil => &["ceil"],
BuiltinScalarFunction::Cos => &["cos"],
Expand Down
10 changes: 0 additions & 10 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,10 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine");
scalar_expr!(Cot, cot, num, "cotangent");
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
scalar_expr!(Atan, atan, num, "inverse tangent");
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
scalar_expr!(Atanh, atanh, num, "inverse hyperbolic tangent");
scalar_expr!(Factorial, factorial, num, "factorial");
scalar_expr!(
Floor,
Expand All @@ -571,7 +567,6 @@ scalar_expr!(Exp, exp, num, "exponential");
scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor");
scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple");
scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument");
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");

scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
Expand Down Expand Up @@ -979,10 +974,6 @@ mod test {
test_unary_scalar_expr!(Cot, cot);
test_unary_scalar_expr!(Sinh, sinh);
test_unary_scalar_expr!(Cosh, cosh);
test_unary_scalar_expr!(Atan, atan);
test_unary_scalar_expr!(Asinh, asinh);
test_unary_scalar_expr!(Acosh, acosh);
test_unary_scalar_expr!(Atanh, atanh);
test_unary_scalar_expr!(Factorial, factorial);
test_unary_scalar_expr!(Floor, floor);
test_unary_scalar_expr!(Ceil, ceil);
Expand All @@ -994,7 +985,6 @@ mod test {
test_nary_scalar_expr!(Trunc, trunc, num, precision);
test_unary_scalar_expr!(Signum, signum);
test_unary_scalar_expr!(Exp, exp);
test_scalar_expr!(Atan2, atan2, y, x);
test_scalar_expr!(Nanvl, nanvl, x, y);
test_scalar_expr!(Iszero, iszero, input);

Expand Down
29 changes: 29 additions & 0 deletions datafusion/functions/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ macro_rules! downcast_arg {
/// $GNAME: a singleton instance of the UDF
/// $NAME: the name of the function
/// $UNARY_FUNC: the unary function to apply to the argument
/// $MONOTONICITY: the monotonicity of the function
macro_rules! make_math_unary_udf {
($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => {
make_udf_function!($NAME::$UDF, $GNAME, $NAME);
Expand Down Expand Up @@ -249,3 +250,31 @@ macro_rules! make_math_unary_udf {
}
};
}

/// Applies a binary function element-wise to two array arguments.
///
/// Both arguments are downcast with `downcast_arg!`, iterated in lockstep, and
/// `$FUNC` is called for every pair in which both elements are non-null. A pair
/// containing a null element produces a null element in the output.
///
/// Two forms are accepted:
/// * `($ARG1, $ARG2, $NAME1, $NAME2, $ARRAY_TYPE, $FUNC)`: both arguments are
///   downcast to `$ARRAY_TYPE`, which is also the collected output type.
/// * `($ARG1, $ARG2, $NAME1, $NAME2, $ARRAY_TYPE1, $ARRAY_TYPE2, $FUNC)`: the
///   arguments are downcast to `$ARRAY_TYPE1` and `$ARRAY_TYPE2` respectively,
///   and the output is collected into `$ARRAY_TYPE1`.
///
/// `$NAME1` / `$NAME2` are forwarded to `downcast_arg!`
/// (NOTE(review): presumably used in its error message; confirm against that macro).
/// `$FUNC` is a block that evaluates to a callable taking the two element values.
#[macro_export]
macro_rules! make_function_inputs2 {
// Form 1: both inputs share a single array type.
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE);

arg1.iter()
.zip(arg2.iter())
.map(|(a1, a2)| match (a1, a2) {
// The second value goes through `try_into`; if that conversion fails,
// `?` makes the closure return `None`, i.e. a null output element.
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
// Any null input yields a null output.
_ => None,
})
.collect::<$ARRAY_TYPE>()
}};
// Form 2: the inputs have different array types; the output uses the first one.
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);

arg1.iter()
.zip(arg2.iter())
.map(|(a1, a2)| match (a1, a2) {
// The second value goes through `try_into`; if that conversion fails,
// `?` makes the closure return `None`, i.e. a null output element.
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
// Any null input yields a null output.
_ => None,
})
.collect::<$ARRAY_TYPE1>()
}};
}
140 changes: 140 additions & 0 deletions datafusion/functions/src/math/atan2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math function: `atan2()`.
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use arrow::datatypes::DataType;
use datafusion_common::DataFusionError;
use datafusion_common::{exec_err, Result};
use datafusion_expr::ColumnarValue;
use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

use crate::make_function_inputs2;
use crate::utils::make_scalar_function;

#[derive(Debug)]
pub(super) struct Atan2 {
signature: Signature,
}

impl Atan2 {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for Atan2 {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"atan2"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use self::DataType::*;
match &arg_types[0] {
Float32 => Ok(Float32),
_ => Ok(Float64),
}
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
make_scalar_function(atan2, vec![])(args)
}
}

/// Atan2 SQL function
///
/// Computes `atan2(y, x)` element-wise, where `args[0]` holds the `y` values
/// and `args[1]` holds the `x` values. The kernel is selected from the data
/// type of `args[0]`; both arrays are then downcast to that same array type by
/// `make_function_inputs2!`.
///
/// # Errors
///
/// Returns an execution error when `args` does not contain exactly two arrays,
/// or when the first array is neither `Float64` nor `Float32`.
pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
    // Validate the arity up front so a malformed call surfaces as an error
    // instead of an out-of-bounds panic on `args[0]` / `args[1]`.
    if args.len() != 2 {
        return exec_err!(
            "atan2 function requires 2 arguments, got {}",
            args.len()
        );
    }

    match args[0].data_type() {
        DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
            &args[0],
            &args[1],
            "y",
            "x",
            Float64Array,
            { f64::atan2 }
        )) as ArrayRef),

        DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
            &args[0],
            &args[1],
            "y",
            "x",
            Float32Array,
            { f32::atan2 }
        )) as ArrayRef),

        other => exec_err!("Unsupported data type {other:?} for function atan2"),
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// `atan2` over `Float64` arrays matches `f64::atan2` for every element.
    #[test]
    fn test_atan2_f64() {
        let ys = [2.0_f64, -3.0, 4.0, -5.0];
        let xs = [1.0_f64, 2.0, -3.0, -4.0];
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(ys.to_vec())), // y
            Arc::new(Float64Array::from(xs.to_vec())), // x
        ];

        let result = atan2(&args).expect("failed to initialize function atan2");
        let floats =
            as_float64_array(&result).expect("failed to initialize function atan2");

        assert_eq!(floats.len(), 4);
        for (idx, (y, x)) in ys.iter().zip(xs.iter()).enumerate() {
            assert_eq!(floats.value(idx), y.atan2(*x));
        }
    }

    /// `atan2` over `Float32` arrays matches `f32::atan2` for every element.
    #[test]
    fn test_atan2_f32() {
        let ys = [2.0_f32, -3.0, 4.0, -5.0];
        let xs = [1.0_f32, 2.0, -3.0, -4.0];
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(ys.to_vec())), // y
            Arc::new(Float32Array::from(xs.to_vec())), // x
        ];

        let result = atan2(&args).expect("failed to initialize function atan2");
        let floats =
            as_float32_array(&result).expect("failed to initialize function atan2");

        assert_eq!(floats.len(), 4);
        for (idx, (y, x)) in ys.iter().zip(xs.iter()).enumerate() {
            assert_eq!(floats.value(idx), y.atan2(*x));
        }
    }
}
14 changes: 13 additions & 1 deletion datafusion/functions/src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
//! "math" DataFusion functions
mod abs;
mod atan2;
mod nans;

// Create UDFs
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
make_udf_function!(abs::AbsFunc, ABS, abs);
make_udf_function!(atan2::Atan2, ATAN2, atan2);

make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
Expand All @@ -33,6 +35,11 @@ make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None);
make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None);
make_math_unary_udf!(TanFunc, TAN, tan, tan, None);

make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)]));
make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)]));
make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)]));
make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)]));

// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(
Expand All @@ -55,5 +62,10 @@ export_functions!(
"returns the arc sine or inverse sine of a number"
),
(tan, num, "returns the tangent of a number"),
(tanh, num, "returns the hyperbolic tangent of a number")
(tanh, num, "returns the hyperbolic tangent of a number"),
(atanh, num, "returns inverse hyperbolic tangent"),
(asinh, num, "returns inverse hyperbolic sine"),
(acosh, num, "returns inverse hyperbolic cosine"),
(atan, num, "returns inverse tangent"),
(atan2, y x, "returns inverse tangent of a division given in the argument")
);
3 changes: 3 additions & 0 deletions datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);

/// Creates a scalar function implementation for the given function.
/// * `inner` - the function to be executed
/// * `hints` - hints to be used when expanding scalars to arrays
pub(super) fn make_scalar_function<F>(
inner: F,
hints: Vec<Hint>,
Expand Down
7 changes: 0 additions & 7 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,6 @@ pub fn create_physical_fun(
) -> Result<ScalarFunctionImplementation> {
Ok(match fun {
// math functions
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh),
BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh),
BuiltinScalarFunction::Atanh => Arc::new(math_expressions::atanh),
BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil),
BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
BuiltinScalarFunction::Cosh => Arc::new(math_expressions::cosh),
Expand Down Expand Up @@ -221,9 +217,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Power => {
Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args))
}
BuiltinScalarFunction::Atan2 => {
Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args))
}
BuiltinScalarFunction::Log => {
Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args))
}
Expand Down
Loading

0 comments on commit 66c8ba2

Please sign in to comment.