diff --git a/src/udf.rs b/src/udf.rs
index 4570e77a..e7ddb825 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -15,67 +15,23 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
+use std::any::Any;
 
 use pyo3::{prelude::*, types::PyTuple};
 
-use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
+use datafusion::arrow::array::{make_array, Array, ArrayData};
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
+use datafusion::common::Result;
 use datafusion::error::DataFusionError;
-use datafusion::logical_expr::function::ScalarFunctionImplementation;
-use datafusion::logical_expr::ScalarUDF;
-use datafusion::logical_expr::{create_udf, ColumnarValue};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Volatility};
+use datafusion::logical_expr::{ScalarUDF, Signature};
+use std::fmt::Debug;
 
 use crate::expr::PyExpr;
 use crate::utils::parse_volatility;
 
-/// Create a Rust callable function from a python function that expects pyarrow arrays
-fn pyarrow_function_to_rust(
-    func: PyObject,
-) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
-    move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
-        Python::with_gil(|py| {
-            // 1. cast args to Pyarrow arrays
-            let py_args = args
-                .iter()
-                .map(|arg| {
-                    arg.into_data()
-                        .to_pyarrow(py)
-                        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
-                })
-                .collect::<Result<Vec<_>, _>>()?;
-            let py_args = PyTuple::new_bound(py, py_args);
-
-            // 2. call function
-            let value = func
-                .call_bound(py, py_args, None)
-                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-
-            // 3. cast to arrow::array::Array
-            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
-                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-            Ok(make_array(array_data))
-        })
-    }
-}
-
-/// Create a DataFusion's UDF implementation from a python function
-/// that expects pyarrow arrays. This is more efficient as it performs
-/// a zero-copy of the contents.
-fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
-    // Make the python function callable from rust
-    let pyarrow_func = pyarrow_function_to_rust(func);
-
-    // Convert input/output from datafusion ColumnarValue to arrow arrays
-    Arc::new(move |args: &[ColumnarValue]| {
-        let array_refs = ColumnarValue::values_to_arrays(args)?;
-        let array_result = pyarrow_func(&array_refs)?;
-        Ok(array_result.into())
-    })
-}
-
 /// Represents a PyScalarUDF
 #[pyclass(name = "ScalarUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -94,14 +50,17 @@ impl PyScalarUDF {
         return_type: PyArrowType<DataType>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let function = create_udf(
+        let function = PythonUDF::new(
             name,
             input_types.0,
             return_type.0,
             parse_volatility(volatility)?,
-            to_scalar_function_impl(func),
+            func,
         );
-        Ok(Self { function })
+
+        Ok(Self {
+            function: function.into(),
+        })
     }
 
     /// creates a new PyExpr with the call of the udf
@@ -115,3 +74,111 @@ impl PyScalarUDF {
         Ok(format!("ScalarUDF({})", self.function.name()))
     }
 }
+
+/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
+/// return type.
+pub struct PythonUDF {
+    pub name: String,
+    pub signature: Signature,
+    // input types preserved as it's a bit messy to get them from the signature
+    pub input_types: Vec<DataType>,
+    pub return_type: DataType,
+    pub func: PyObject,
+}
+
+impl Debug for PythonUDF {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("PythonUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("input_types", &self.input_types)
+            .field("return_type", &self.return_type)
+            .field("func", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl PythonUDF {
+    /// Create a new `PythonUDF` from a name, input types, return type and
+    /// implementation.
+    pub fn new(
+        name: impl Into<String>,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        volatility: Volatility,
+        func: PyObject,
+    ) -> Self {
+        Self::new_with_signature(
+            name,
+            Signature::exact(input_types.clone(), volatility),
+            input_types,
+            return_type,
+            func,
+        )
+    }
+
+    /// Create a new `PythonUDF` from a name, signature, return type and
+    /// implementation.
+    pub fn new_with_signature(
+        name: impl Into<String>,
+        signature: Signature,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+
+        func: PyObject,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            signature,
+            input_types,
+            return_type,
+            func,
+        }
+    }
+}
+
+impl ScalarUDFImpl for PythonUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result<ColumnarValue> {
+        let array_refs = ColumnarValue::values_to_arrays(args)?;
+        let array_data: Result<_> = Python::with_gil(|py| {
+            // 1. cast args to PyArrow arrays
+            let py_args = array_refs
+                .iter()
+                .map(|arg| {
+                    arg.into_data()
+                        .to_pyarrow(py)
+                        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
+                })
+                .collect::<Result<Vec<_>, _>>()?;
+            let py_args = PyTuple::new_bound(py, py_args);
+
+            // 2. call function
+            let value = self
+                .func
+                .call_bound(py, py_args, None)
+                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+            // 3. cast to arrow::array::Array
+            ArrayData::from_pyarrow_bound(value.bind(py))
+                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
+        });
+
+        Ok(make_array(array_data?).into())
+    }
+}
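For reference, here is a minimal stand-alone sketch (not part of this patch) of the same `ScalarUDFImpl` pattern applied to a pure-Rust function, to illustrate what replaces the old `create_udf` + `ScalarFunctionImplementation` closure. The `AddOne` type and `add_one_udf` helper are hypothetical names chosen for the example, and it assumes a DataFusion release that exposes `invoke_batch`, as the patch above already requires.

use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::cast::as_int64_array;
use datafusion::common::Result;
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility};

/// Hypothetical pure-Rust counterpart of `PythonUDF`: adds one to an Int64 column.
#[derive(Debug)]
struct AddOne {
    signature: Signature,
}

impl ScalarUDFImpl for AddOne {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "add_one"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Int64)
    }

    fn invoke_batch(&self, args: &[ColumnarValue], _number_rows: usize) -> Result<ColumnarValue> {
        // Normalize scalar/array inputs to arrays, the same first step PythonUDF takes.
        let arrays = ColumnarValue::values_to_arrays(args)?;
        let input = as_int64_array(&arrays[0])?;
        let result: Int64Array = input.iter().map(|v| v.map(|x| x + 1)).collect();
        Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
    }
}

/// Wrap the impl into a `ScalarUDF`; this is the same conversion that
/// `function.into()` performs in `PyScalarUDF::new` above.
fn add_one_udf() -> ScalarUDF {
    ScalarUDF::from(AddOne {
        signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable),
    })
}

The struct-based form carries the name, signature and return type alongside the callable, which is why `PythonUDF` no longer needs the Arc'd closure that `to_scalar_function_impl` used to build.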