From 4e06013986b305581ede48ccae58207dae83a290 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 24 Jan 2024 13:12:09 -0600 Subject: [PATCH 01/17] ScalarValue return types from argument values --- datafusion-examples/examples/complex_udf.rs | 170 ++++++++++++++++++++ datafusion/expr/src/udf.rs | 30 +++- 2 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 datafusion-examples/examples/complex_udf.rs diff --git a/datafusion-examples/examples/complex_udf.rs b/datafusion-examples/examples/complex_udf.rs new file mode 100644 index 000000000000..ac18fbff3dd4 --- /dev/null +++ b/datafusion-examples/examples/complex_udf.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow_schema::{Field, Schema}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; + +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{ + internal_err, DFSchema, DataFusionError, ScalarValue, ToDFSchema, +}; +use datafusion_expr::{ + expr::ScalarFunction, ColumnarValue, ExprSchemable, ScalarUDF, ScalarUDFImpl, + Signature, +}; + +#[derive(Debug)] +struct UDFWithExprReturn { + signature: Signature, +} + +impl UDFWithExprReturn { + fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + } + } +} + +//Implement the ScalarUDFImpl trait for UDFWithExprReturn +impl ScalarUDFImpl for UDFWithExprReturn { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "udf_with_expr_return" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Int32) + } + // An example of how to use the exprs to determine the return type + // If the third argument is '0', return the type of the first argument + // If the third argument is '1', return the type of the second argument + fn return_type_from_exprs( + &self, + arg_exprs: &[Expr], + schema: &DFSchema, + ) -> Result { + if arg_exprs.len() != 3 { + return internal_err!("The size of the args must be 3."); + } + let take_idx = match arg_exprs.get(2).unwrap() { + Expr::Literal(ScalarValue::Int64(Some(idx))) if (idx == &0 || idx == &1) => { + *idx as usize + } + _ => unreachable!(), + }; + arg_exprs.get(take_idx).unwrap().get_type(schema) + } + // The actual implementation would add one to the argument + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } +} + +#[derive(Debug)] +struct UDFDefault { + signature: Signature, +} + +impl UDFDefault { + fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + } + } +} + +// Implement the ScalarUDFImpl trait for UDFDefault +// This is the same as UDFWithExprReturn, except without return_type_from_exprs +impl ScalarUDFImpl for UDFDefault { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "udf_default" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + // The actual implementation would add one to the argument + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new ScalarUDF from the implementation + let udf_with_expr_return = ScalarUDF::from(UDFWithExprReturn::new()); + + // Call 'return_type' to get the return type of the function + let ret = udf_with_expr_return.return_type(&[DataType::Int32])?; + assert_eq!(ret, DataType::Int32); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float64, false), + ]) + .to_dfschema()?; + + // Set the third argument to 0 to return the type of the first argument + let expr0 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(0_i64)]); + let args = match expr0 { + Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, + _ => panic!("Expected ScalarFunction"), + }; + let ret = udf_with_expr_return.return_type_from_exprs(&args, &schema)?; + // The return type should be the same as the first argument + assert_eq!(ret, DataType::Float32); + + // Set the third argument to 1 to return the type of the second argument + let expr1 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(1_i64)]); + let args1 = match expr1 { + Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, + _ => panic!("Expected ScalarFunction"), + }; + let ret = udf_with_expr_return.return_type_from_exprs(&args1, &schema)?; + // The return type should be the same as the second argument + assert_eq!(ret, DataType::Float64); + + // Create a new ScalarUDF from the implementation + let udf_default = ScalarUDF::from(UDFDefault::new()); + // Call 'return_type' to get the return type of the function + let ret = udf_default.return_type(&[DataType::Int32])?; + assert_eq!(ret, DataType::Boolean); + + // Set the third argument to 0 to return the type of the first argument + let expr2 = udf_default.call(vec![col("a"), col("b"), lit(0_i64)]); + let args = match expr2 { + Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, + _ => panic!("Expected ScalarFunction"), + }; + let ret = udf_default.return_type_from_exprs(&args, &schema)?; + assert_eq!(ret, DataType::Boolean); + + Ok(()) +} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3017e1ec0271..fb636e35747d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,12 +17,13 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::ExprSchemable; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{DFSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -152,6 +153,17 @@ impl ScalarUDF { self.inner.return_type(args) } + /// The datatype this function returns given the input argument input types. + /// This function is used when the input arguments are [`Expr`]s. + /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. + pub fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &DFSchema, + ) -> Result { + self.inner.return_type_from_exprs(args, schema) + } + /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke`] for more details. @@ -249,6 +261,22 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What [`DataType`] will be returned by this function, given the types of + /// the expr arguments + fn return_type_from_exprs( + &self, + arg_exprs: &[Expr], + schema: &DFSchema, + ) -> Result { + // provide default implementation that calls `self.return_type()` + // so that people don't have to implement `return_type_from_exprs` if they dont want to + let arg_types = arg_exprs + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + self.return_type(&arg_types) + } + /// Invoke the function on `args`, returning the appropriate result /// /// The function will be invoked passed with the slice of [`ColumnarValue`] From 17a2c9189460adff7993e54b8bccfada4dbd0a0a Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 24 Jan 2024 13:20:33 -0600 Subject: [PATCH 02/17] change file name --- .../examples/{complex_udf.rs => return_types_udf.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename datafusion-examples/examples/{complex_udf.rs => return_types_udf.rs} (100%) diff --git a/datafusion-examples/examples/complex_udf.rs b/datafusion-examples/examples/return_types_udf.rs similarity index 100% rename from datafusion-examples/examples/complex_udf.rs rename to datafusion-examples/examples/return_types_udf.rs From 56b71ae5d40dbed249e3b6822b2f97c1637da45e Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Thu, 8 Feb 2024 21:28:57 -0600 Subject: [PATCH 03/17] try using ?Sized --- datafusion-examples/examples/return_types_udf.rs | 4 ++-- datafusion/expr/src/expr_schema.rs | 8 ++++---- datafusion/expr/src/udf.rs | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/datafusion-examples/examples/return_types_udf.rs b/datafusion-examples/examples/return_types_udf.rs index ac18fbff3dd4..f7cbeaf88f9f 100644 --- a/datafusion-examples/examples/return_types_udf.rs +++ b/datafusion-examples/examples/return_types_udf.rs @@ -23,7 +23,7 @@ use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{ - internal_err, DFSchema, DataFusionError, ScalarValue, ToDFSchema, + internal_err, DataFusionError, ExprSchema, ScalarValue, ToDFSchema }; use datafusion_expr::{ expr::ScalarFunction, ColumnarValue, ExprSchemable, ScalarUDF, ScalarUDFImpl, @@ -63,7 +63,7 @@ impl ScalarUDFImpl for UDFWithExprReturn { fn return_type_from_exprs( &self, arg_exprs: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { if arg_exprs.len() != 3 { return internal_err!("The size of the args must be 3."); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 517d7a35f70a..6b718a694e56 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -37,7 +37,7 @@ use std::sync::Arc; /// trait to allow expr to typable with respect to a schema pub trait ExprSchemable { /// given a schema, return the type of the expr - fn get_type(&self, schema: &S) -> Result; + fn get_type(&self, schema: &S) -> Result; /// given a schema, return the nullability of the expr fn nullable(&self, input_schema: &S) -> Result; @@ -90,7 +90,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - fn get_type(&self, schema: &S) -> Result { + fn get_type(&self, schema: &S) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { @@ -136,7 +136,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - Ok(fun.return_type(&arg_data_types)?) + fun.return_type_from_exprs(args, schema) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -394,7 +394,7 @@ impl ExprSchemable for Expr { } /// return the schema [`Field`] for the type referenced by `get_indexed_field` -fn field_for_index( +fn field_for_index( expr: &Expr, field: &GetFieldAccess, schema: &S, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index fb636e35747d..d7f5fb746ada 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DFSchema, ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -159,7 +159,7 @@ impl ScalarUDF { pub fn return_type_from_exprs( &self, args: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { self.inner.return_type_from_exprs(args, schema) } @@ -266,7 +266,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn return_type_from_exprs( &self, arg_exprs: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { // provide default implementation that calls `self.return_type()` // so that people don't have to implement `return_type_from_exprs` if they dont want to From 3dbc0c78351f4dd516b2a6f8969285ae9a4acec0 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Thu, 8 Feb 2024 21:49:10 -0600 Subject: [PATCH 04/17] use Ok --- datafusion/expr/src/expr_schema.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 6b718a694e56..20f929c17065 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -136,7 +136,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - fun.return_type_from_exprs(args, schema) + Ok(fun.return_type_from_exprs(args, schema)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") From 491a4a110d8cc828c5d6783385b3c7ce50859867 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Fri, 9 Feb 2024 15:50:48 -0600 Subject: [PATCH 05/17] move method default impl outside trait --- .../examples/return_types_udf.rs | 151 ++++++++---------- datafusion/expr/src/expr_schema.rs | 8 +- datafusion/expr/src/udf.rs | 29 ++-- 3 files changed, 85 insertions(+), 103 deletions(-) diff --git a/datafusion-examples/examples/return_types_udf.rs b/datafusion-examples/examples/return_types_udf.rs index f7cbeaf88f9f..95dd5da5332b 100644 --- a/datafusion-examples/examples/return_types_udf.rs +++ b/datafusion-examples/examples/return_types_udf.rs @@ -15,20 +15,48 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - -use arrow_schema::{Field, Schema}; -use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion::{ + arrow::{ + array::{Float32Array, Float64Array}, + datatypes::DataType, + record_batch::RecordBatch, + }, + logical_expr::Volatility, +}; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion_common::{ - internal_err, DataFusionError, ExprSchema, ScalarValue, ToDFSchema -}; +use datafusion_common::{ExprSchema, ScalarValue}; use datafusion_expr::{ - expr::ScalarFunction, ColumnarValue, ExprSchemable, ScalarUDF, ScalarUDFImpl, - Signature, + ColumnarValue, ExprSchemable, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature }; +use std::{any::Any, sync::Arc}; + +// create local execution context with an in-memory table +fn create_context() -> Result { + use datafusion::arrow::datatypes::{Field, Schema}; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float64, false), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1, 6.1])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])), + ], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + ctx.register_batch("t", batch)?; + Ok(ctx) +} #[derive(Debug)] struct UDFWithExprReturn { @@ -49,13 +77,13 @@ impl ScalarUDFImpl for UDFWithExprReturn { self } fn name(&self) -> &str { - "udf_with_expr_return" + "my_cast" } fn signature(&self) -> &Signature { &self.signature } fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Int32) + Ok(DataType::Boolean) } // An example of how to use the exprs to determine the return type // If the third argument is '0', return the type of the first argument @@ -64,7 +92,7 @@ impl ScalarUDFImpl for UDFWithExprReturn { &self, arg_exprs: &[Expr], schema: &dyn ExprSchema, - ) -> Result { + ) -> Option> { if arg_exprs.len() != 3 { return internal_err!("The size of the args must be 3."); } @@ -74,97 +102,44 @@ impl ScalarUDFImpl for UDFWithExprReturn { } _ => unreachable!(), }; - arg_exprs.get(take_idx).unwrap().get_type(schema) + Some(arg_exprs.get(take_idx).unwrap().get_type(schema)) } // The actual implementation would add one to the argument fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))) } } -#[derive(Debug)] -struct UDFDefault { - signature: Signature, -} -impl UDFDefault { - fn new() -> Self { - Self { - signature: Signature::any(3, Volatility::Immutable), - } - } -} - -// Implement the ScalarUDFImpl trait for UDFDefault -// This is the same as UDFWithExprReturn, except without return_type_from_exprs -impl ScalarUDFImpl for UDFDefault { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "udf_default" - } - fn signature(&self) -> &Signature { - &self.signature - } - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - // The actual implementation would add one to the argument - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() - } -} #[tokio::main] async fn main() -> Result<()> { // Create a new ScalarUDF from the implementation let udf_with_expr_return = ScalarUDF::from(UDFWithExprReturn::new()); - // Call 'return_type' to get the return type of the function - let ret = udf_with_expr_return.return_type(&[DataType::Int32])?; - assert_eq!(ret, DataType::Int32); + let ctx = create_context()?; - let schema = Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ]) - .to_dfschema()?; - - // Set the third argument to 0 to return the type of the first argument - let expr0 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(0_i64)]); - let args = match expr0 { - Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, - _ => panic!("Expected ScalarFunction"), - }; - let ret = udf_with_expr_return.return_type_from_exprs(&args, &schema)?; - // The return type should be the same as the first argument - assert_eq!(ret, DataType::Float32); - - // Set the third argument to 1 to return the type of the second argument - let expr1 = udf_with_expr_return.call(vec![col("a"), col("b"), lit(1_i64)]); - let args1 = match expr1 { - Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, - _ => panic!("Expected ScalarFunction"), - }; - let ret = udf_with_expr_return.return_type_from_exprs(&args1, &schema)?; - // The return type should be the same as the second argument - assert_eq!(ret, DataType::Float64); + ctx.register_udf(udf_with_expr_return); - // Create a new ScalarUDF from the implementation - let udf_default = ScalarUDF::from(UDFDefault::new()); - // Call 'return_type' to get the return type of the function - let ret = udf_default.return_type(&[DataType::Int32])?; - assert_eq!(ret, DataType::Boolean); - - // Set the third argument to 0 to return the type of the first argument - let expr2 = udf_default.call(vec![col("a"), col("b"), lit(0_i64)]); - let args = match expr2 { - Expr::ScalarFunction(ScalarFunction { func_def: _, args }) => args, - _ => panic!("Expected ScalarFunction"), - }; - let ret = udf_default.return_type_from_exprs(&args, &schema)?; - assert_eq!(ret, DataType::Boolean); + // SELECT take(a, b, 0) AS take0, take(a, b, 1) AS take1 FROM t; + let df = ctx.table("t").await?; + let take = df.registry().udf("my_cast")?; + let expr0 = take + .call(vec![col("a"), lit("i32")]) + .alias("take0"); + let expr1 = take + .call(vec![col("a"), lit("i64")]) + .alias("take1"); + + let df = df.select(vec![expr0, expr1])?; + let schema = df.schema(); + + // Check output schema + assert_eq!(schema.field(0).data_type(), &DataType::Int32); + assert_eq!(schema.field(1).data_type(), &DataType::Int64); + + df.show().await?; + Ok(()) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 20f929c17065..d45d6cd8341e 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -37,7 +37,7 @@ use std::sync::Arc; /// trait to allow expr to typable with respect to a schema pub trait ExprSchemable { /// given a schema, return the type of the expr - fn get_type(&self, schema: &S) -> Result; + fn get_type(&self, schema: &S) -> Result; /// given a schema, return the nullability of the expr fn nullable(&self, input_schema: &S) -> Result; @@ -90,7 +90,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - fn get_type(&self, schema: &S) -> Result { + fn get_type(&self, schema: &S) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { @@ -136,7 +136,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - Ok(fun.return_type_from_exprs(args, schema)?) + Ok(fun.return_type_from_exprs(&args, schema)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -394,7 +394,7 @@ impl ExprSchemable for Expr { } /// return the schema [`Field`] for the type referenced by `get_indexed_field` -fn field_for_index( +fn field_for_index( expr: &Expr, field: &GetFieldAccess, schema: &S, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index d7f5fb746ada..8e1651967cce 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, ExprSchema, Result}; +use datafusion_common::{ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -155,13 +155,24 @@ impl ScalarUDF { /// The datatype this function returns given the input argument input types. /// This function is used when the input arguments are [`Expr`]s. + /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. - pub fn return_type_from_exprs( + pub fn return_type_from_exprs( &self, args: &[Expr], - schema: &dyn ExprSchema, + schema: &S, ) -> Result { - self.inner.return_type_from_exprs(args, schema) + // If the implementation provides a return_type_from_exprs, use it + if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) { + return_type + // Otherwise, use the return_type function + } else { + let arg_types = args + .iter() + .map(|arg| arg.get_type(schema)) + .collect::>>()?; + self.return_type(&arg_types) + } } /// Invoke the function on `args`, returning the appropriate result. @@ -267,14 +278,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { &self, arg_exprs: &[Expr], schema: &dyn ExprSchema, - ) -> Result { - // provide default implementation that calls `self.return_type()` + ) -> Option> { + // The default implementation returns None // so that people don't have to implement `return_type_from_exprs` if they dont want to - let arg_types = arg_exprs - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - self.return_type(&arg_types) + None } /// Invoke the function on `args`, returning the appropriate result From 468b38f6d523fd1b49b831d756b3b86131e8bcbf Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Sat, 10 Feb 2024 23:46:37 -0600 Subject: [PATCH 06/17] Use type trait for ExprSchemable --- .../examples/return_types_udf.rs | 31 +++++++++++-------- datafusion/expr/src/expr_schema.rs | 29 ++++++++--------- datafusion/expr/src/udf.rs | 8 ++--- datafusion/physical-expr/src/udf.rs | 4 +-- 4 files changed, 39 insertions(+), 33 deletions(-) diff --git a/datafusion-examples/examples/return_types_udf.rs b/datafusion-examples/examples/return_types_udf.rs index 95dd5da5332b..b82550fef7b1 100644 --- a/datafusion-examples/examples/return_types_udf.rs +++ b/datafusion-examples/examples/return_types_udf.rs @@ -26,7 +26,7 @@ use datafusion::{ use datafusion::error::Result; use datafusion::prelude::*; -use datafusion_common::{ExprSchema, ScalarValue}; +use datafusion_common::{internal_err, DataFusionError, DFSchema, ExprSchema, ScalarValue}; use datafusion_expr::{ ColumnarValue, ExprSchemable, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature }; @@ -82,8 +82,8 @@ impl ScalarUDFImpl for UDFWithExprReturn { fn signature(&self) -> &Signature { &self.signature } - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) + fn return_type(&self, args: &[DataType]) -> Result { + Ok(DataType::Float32) } // An example of how to use the exprs to determine the return type // If the third argument is '0', return the type of the first argument @@ -91,10 +91,10 @@ impl ScalarUDFImpl for UDFWithExprReturn { fn return_type_from_exprs( &self, arg_exprs: &[Expr], - schema: &dyn ExprSchema, + schema: &DFSchema, ) -> Option> { if arg_exprs.len() != 3 { - return internal_err!("The size of the args must be 3."); + return Some(internal_err!("The size of the args must be 3.")); } let take_idx = match arg_exprs.get(2).unwrap() { Expr::Literal(ScalarValue::Int64(Some(idx))) if (idx == &0 || idx == &1) => { @@ -105,8 +105,15 @@ impl ScalarUDFImpl for UDFWithExprReturn { Some(arg_exprs.get(take_idx).unwrap().get_type(schema)) } // The actual implementation would add one to the argument - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0)))) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let take_idx = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + _ => unreachable!(), + }; + match &args[take_idx] { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), + ColumnarValue::Scalar(_) => unimplemented!(), + } } } @@ -125,18 +132,16 @@ async fn main() -> Result<()> { let df = ctx.table("t").await?; let take = df.registry().udf("my_cast")?; let expr0 = take - .call(vec![col("a"), lit("i32")]) - .alias("take0"); + .call(vec![col("a"), col("b"), lit(0_i64)]); let expr1 = take - .call(vec![col("a"), lit("i64")]) - .alias("take1"); + .call(vec![col("a"), col("b"), lit(1_i64)]); let df = df.select(vec![expr0, expr1])?; let schema = df.schema(); // Check output schema - assert_eq!(schema.field(0).data_type(), &DataType::Int32); - assert_eq!(schema.field(1).data_type(), &DataType::Int64); + assert_eq!(schema.field(0).data_type(), &DataType::Float32); + assert_eq!(schema.field(1).data_type(), &DataType::Float64); df.show().await?; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d45d6cd8341e..3b13ea7d2127 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -35,24 +35,24 @@ use std::collections::HashMap; use std::sync::Arc; /// trait to allow expr to typable with respect to a schema -pub trait ExprSchemable { +pub trait ExprSchemable { /// given a schema, return the type of the expr - fn get_type(&self, schema: &S) -> Result; + fn get_type(&self, schema: &S) -> Result; /// given a schema, return the nullability of the expr - fn nullable(&self, input_schema: &S) -> Result; + fn nullable(&self, input_schema: &S) -> Result; /// given a schema, return the expr's optional metadata - fn metadata(&self, schema: &S) -> Result>; + fn metadata(&self, schema: &S) -> Result>; /// convert to a field with respect to a schema fn to_field(&self, input_schema: &DFSchema) -> Result; /// cast to a type with respect to a schema - fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; + fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; } -impl ExprSchemable for Expr { +impl ExprSchemable for Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] /// @@ -90,7 +90,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - fn get_type(&self, schema: &S) -> Result { + fn get_type(&self, schema: &DFSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { @@ -136,7 +136,8 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - Ok(fun.return_type_from_exprs(&args, schema)?) + let t = fun.return_type_from_exprs(&args, schema)?; + Ok(t) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -220,7 +221,7 @@ impl ExprSchemable for Expr { /// This function errors when it is not possible to compute its /// nullability. This happens when the expression refers to a /// column that does not exist in the schema. - fn nullable(&self, input_schema: &S) -> Result { + fn nullable(&self, input_schema: &DFSchema) -> Result { match self { Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) @@ -327,9 +328,9 @@ impl ExprSchemable for Expr { } } - fn metadata(&self, schema: &S) -> Result> { + fn metadata(&self, schema: &DFSchema) -> Result> { match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), + Expr::Column(c) => Ok(schema.metadata().clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), _ => Ok(HashMap::new()), } @@ -370,7 +371,7 @@ impl ExprSchemable for Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. - fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result { + fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { return Ok(self); @@ -394,10 +395,10 @@ impl ExprSchemable for Expr { } /// return the schema [`Field`] for the type referenced by `get_indexed_field` -fn field_for_index( +fn field_for_index( expr: &Expr, field: &GetFieldAccess, - schema: &S, + schema: &DFSchema, ) -> Result { let expr_dt = expr.get_type(schema)?; match field { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 8e1651967cce..1cf3614d75d7 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{ExprSchema, Result}; +use datafusion_common::{DFSchema, ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -157,10 +157,10 @@ impl ScalarUDF { /// This function is used when the input arguments are [`Expr`]s. /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. - pub fn return_type_from_exprs( + pub fn return_type_from_exprs( &self, args: &[Expr], - schema: &S, + schema: &DFSchema, ) -> Result { // If the implementation provides a return_type_from_exprs, use it if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) { @@ -277,7 +277,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn return_type_from_exprs( &self, arg_exprs: &[Expr], - schema: &dyn ExprSchema, + schema: &DFSchema, ) -> Option> { // The default implementation returns None // so that people don't have to implement `return_type_from_exprs` if they dont want to diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index e0117fecb4e8..d42a8f21d870 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -18,7 +18,7 @@ //! UDF support use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow::datatypes::Schema; -use datafusion_common::Result; +use datafusion_common::{schema_datafusion_err, Result}; pub use datafusion_expr::ScalarUDF; use std::sync::Arc; @@ -38,7 +38,7 @@ pub fn create_physical_expr( fun.name(), fun.fun(), input_phy_exprs.to_vec(), - fun.return_type(&input_exprs_types)?, + fun.return_type_from_exprs(&input_exprs_types, input_schema)?, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), ))) From 5772d9f83f93bb1d932bd16bc26deac8e36844aa Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Sat, 10 Feb 2024 23:54:17 -0600 Subject: [PATCH 07/17] fix nit --- datafusion/physical-expr/src/udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index d42a8f21d870..72917eceb952 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -38,7 +38,7 @@ pub fn create_physical_expr( fun.name(), fun.fun(), input_phy_exprs.to_vec(), - fun.return_type_from_exprs(&input_exprs_types, input_schema)?, + fun.return_type_from_exprs(&input_phy_exprs, input_schema)?, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), ))) From 59b395818fbf2f957c0f7f7b6a3cfc2f199cea8c Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 16:18:28 -0500 Subject: [PATCH 08/17] Proposed Return Type from Expr suggestions (#1) * Improve return_type_from_args * Rework example * Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs --------- Co-authored-by: Junhao Liu --- .../examples/return_types_udf.rs | 150 ------------------ .../user_defined_scalar_functions.rs | 142 ++++++++++++++++- datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/udf.rs | 67 ++++---- datafusion/physical-expr/src/planner.rs | 14 +- datafusion/physical-expr/src/udf.rs | 16 +- 6 files changed, 193 insertions(+), 198 deletions(-) delete mode 100644 datafusion-examples/examples/return_types_udf.rs diff --git a/datafusion-examples/examples/return_types_udf.rs b/datafusion-examples/examples/return_types_udf.rs deleted file mode 100644 index b82550fef7b1..000000000000 --- a/datafusion-examples/examples/return_types_udf.rs +++ /dev/null @@ -1,150 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::{ - arrow::{ - array::{Float32Array, Float64Array}, - datatypes::DataType, - record_batch::RecordBatch, - }, - logical_expr::Volatility, -}; - -use datafusion::error::Result; -use datafusion::prelude::*; -use datafusion_common::{internal_err, DataFusionError, DFSchema, ExprSchema, ScalarValue}; -use datafusion_expr::{ - ColumnarValue, ExprSchemable, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature -}; -use std::{any::Any, sync::Arc}; - -// create local execution context with an in-memory table -fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float64, false), - ])); - - // define data. - let batch = RecordBatch::try_new( - schema, - vec![ - Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1, 6.1])), - Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0])), - ], - )?; - - // declare a new context. In spark API, this corresponds to a new spark SQLsession - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - ctx.register_batch("t", batch)?; - Ok(ctx) -} - -#[derive(Debug)] -struct UDFWithExprReturn { - signature: Signature, -} - -impl UDFWithExprReturn { - fn new() -> Self { - Self { - signature: Signature::any(3, Volatility::Immutable), - } - } -} - -//Implement the ScalarUDFImpl trait for UDFWithExprReturn -impl ScalarUDFImpl for UDFWithExprReturn { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "my_cast" - } - fn signature(&self) -> &Signature { - &self.signature - } - fn return_type(&self, args: &[DataType]) -> Result { - Ok(DataType::Float32) - } - // An example of how to use the exprs to determine the return type - // If the third argument is '0', return the type of the first argument - // If the third argument is '1', return the type of the second argument - fn return_type_from_exprs( - &self, - arg_exprs: &[Expr], - schema: &DFSchema, - ) -> Option> { - if arg_exprs.len() != 3 { - return Some(internal_err!("The size of the args must be 3.")); - } - let take_idx = match arg_exprs.get(2).unwrap() { - Expr::Literal(ScalarValue::Int64(Some(idx))) if (idx == &0 || idx == &1) => { - *idx as usize - } - _ => unreachable!(), - }; - Some(arg_exprs.get(take_idx).unwrap().get_type(schema)) - } - // The actual implementation would add one to the argument - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let take_idx = match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, - _ => unreachable!(), - }; - match &args[take_idx] { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), - ColumnarValue::Scalar(_) => unimplemented!(), - } - } -} - - - -#[tokio::main] -async fn main() -> Result<()> { - // Create a new ScalarUDF from the implementation - let udf_with_expr_return = ScalarUDF::from(UDFWithExprReturn::new()); - - let ctx = create_context()?; - - ctx.register_udf(udf_with_expr_return); - - // SELECT take(a, b, 0) AS take0, take(a, b, 1) AS take1 FROM t; - let df = ctx.table("t").await?; - let take = df.registry().udf("my_cast")?; - let expr0 = take - .call(vec![col("a"), col("b"), lit(0_i64)]); - let expr1 = take - .call(vec![col("a"), col("b"), lit(1_i64)]); - - let df = df.select(vec![expr0, expr1])?; - let schema = df.schema(); - - // Check output schema - assert_eq!(schema.field(0).data_type(), &DataType::Float32); - assert_eq!(schema.field(1).data_type(), &DataType::Float64); - - df.show().await?; - - - Ok(()) -} diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a86c76b9b6dd..f3bd085aab16 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; -use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; +use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, + plan_err, DFSchema, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::{ - create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, + LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use rand::{thread_rng, Rng}; +use std::any::Any; use std::iter; use std::sync::Arc; @@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> { Ok(()) } +#[derive(Debug)] +struct TakeUDF { + signature: Signature, +} + +impl TakeUDF { + fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + } + } +} + +/// Implement a ScalarUDFImpl whose return type is a function of the input values +impl ScalarUDFImpl for TakeUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "take" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + not_impl_err!("Not called because the return_type_from_exprs is implemented") + } + + /// Thus function returns the type of the first or second argument based on + /// the third argument: + /// + /// 1. If the third argument is '0', return the type of the first argument + /// 2. If the third argument is '1', return the type of the second argument + fn return_type_from_exprs( + &self, + arg_exprs: &[Expr], + schema: &DFSchema, + ) -> Result { + if arg_exprs.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); + } + + let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = + arg_exprs.get(2) + { + if *idx == 0 || *idx == 1 { + *idx as usize + } else { + return plan_err!("The third argument must be 0 or 1, got: {idx}"); + } + } else { + return plan_err!( + "The third argument must be a literal of type int64, but got {:?}", + arg_exprs.get(2) + ); + }; + + arg_exprs.get(take_idx).unwrap().get_type(schema) + } + + // The actual implementation rethr + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let take_idx = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + _ => unreachable!(), + }; + match &args[take_idx] { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), + ColumnarValue::Scalar(_) => unimplemented!(), + } + } +} + +#[tokio::test] +async fn verify_udf_return_type() -> Result<()> { + // Create a new ScalarUDF from the implementation + let take = ScalarUDF::from(TakeUDF::new()); + + // SELECT + // take(smallint_col, double_col, 0) as take0, + // take(smallint_col, double_col, 1) as take1 + // FROM alltypes_plain; + let exprs = vec![ + take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)]) + .alias("take0"), + take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)]) + .alias("take1"), + ]; + + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await?; + + let df = ctx.table("alltypes_plain").await?.select(exprs)?; + + let schema = df.schema(); + + // The output schema should be + // * type of column smallint_col (float64) + // * type of column double_col (float32) + assert_eq!(schema.field(0).data_type(), &DataType::Int32); + assert_eq!(schema.field(1).data_type(), &DataType::Float64); + + let expected = [ + "+-------+-------+", + "| take0 | take1 |", + "+-------+-------+", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "+-------+-------+", + ]; + assert_batches_sorted_eq!(&expected, &df.collect().await?); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF @@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { Ok(()) } +async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + Ok(()) +} + /// Execute SQL and return results as a RecordBatch async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3b13ea7d2127..bd5f050ddedd 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -330,7 +330,7 @@ impl ExprSchemable for Expr { fn metadata(&self, schema: &DFSchema) -> Result> { match self { - Expr::Column(c) => Ok(schema.metadata().clone()), + Expr::Column(_) => Ok(schema.metadata().clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), _ => Ok(HashMap::new()), } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1cf3614d75d7..361285ad9e47 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, ExprSchema, Result}; +use datafusion_common::{DFSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -111,7 +111,7 @@ impl ScalarUDF { /// /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) + Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -146,16 +146,9 @@ impl ScalarUDF { self.inner.signature() } - /// The datatype this function returns given the input argument input types. - /// - /// See [`ScalarUDFImpl::return_type`] for more details. - pub fn return_type(&self, args: &[DataType]) -> Result { - self.inner.return_type(args) - } - /// The datatype this function returns given the input argument input types. /// This function is used when the input arguments are [`Expr`]s. - /// + /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. pub fn return_type_from_exprs( &self, @@ -163,16 +156,7 @@ impl ScalarUDF { schema: &DFSchema, ) -> Result { // If the implementation provides a return_type_from_exprs, use it - if let Some(return_type) = self.inner.return_type_from_exprs(args, schema) { - return_type - // Otherwise, use the return_type function - } else { - let arg_types = args - .iter() - .map(|arg| arg.get_type(schema)) - .collect::>>()?; - self.return_type(&arg_types) - } + self.inner.return_type_from_exprs(args, schema) } /// Invoke the function on `args`, returning the appropriate result. @@ -272,16 +256,41 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; - /// What [`DataType`] will be returned by this function, given the types of - /// the expr arguments + /// What [`DataType`] will be returned by this function, given the + /// arguments? + /// + /// Note most UDFs should implement [`Self::return_type`] and not this + /// function. The output type for most functions only depends on the types + /// of their inputs (e.g. `sqrt(f32)` is always `f32`). + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// This method can be overridden for functions that return different + /// *types* based on the *values* of their arguments. + /// + /// For example, the following two function calls get the same argument + /// types (something and a `Utf8` string) but return different types based + /// on the value of the second argument: + /// + /// * `arrow_cast(x, 'Int16')` --> `Int16` + /// * `arrow_cast(x, 'Float32')` --> `Float32` + /// + /// # Notes: + /// + /// This function must consistently return the same type for the same + /// logical input even if the input is simplified (e.g. it must return the same + /// value for `('foo' | 'bar')` as it does for ('foobar'). fn return_type_from_exprs( &self, - arg_exprs: &[Expr], + args: &[Expr], schema: &DFSchema, - ) -> Option> { - // The default implementation returns None - // so that people don't have to implement `return_type_from_exprs` if they dont want to - None + ) -> Result { + let arg_types = args + .iter() + .map(|arg| arg.get_type(schema)) + .collect::>>()?; + self.return_type(&arg_types) } /// Invoke the function on `args`, returning the appropriate result @@ -325,13 +334,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. #[derive(Debug)] struct AliasedScalarUDFImpl { - inner: ScalarUDF, + inner: Arc, aliases: Vec, } impl AliasedScalarUDFImpl { pub fn new( - inner: ScalarUDF, + inner: Arc, new_aliases: impl IntoIterator, ) -> Self { let mut aliases = inner.aliases().to_vec(); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 6408af5cda99..b8491aea2d6f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -272,11 +272,15 @@ pub fn create_physical_expr( execution_props, ) } - ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ), + ScalarFunctionDefinition::UDF(fun) => { + let return_type = fun.return_type_from_exprs(args, input_dfschema)?; + + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + return_type, + ) + } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 72917eceb952..05d9f99c422d 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -17,28 +17,24 @@ //! UDF support use crate::{PhysicalExpr, ScalarFunctionExpr}; -use arrow::datatypes::Schema; -use datafusion_common::{schema_datafusion_err, Result}; +use arrow_schema::DataType; +use datafusion_common::Result; pub use datafusion_expr::ScalarUDF; use std::sync::Arc; /// Create a physical expression of the UDF. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDF. +/// +/// Arguments: pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], - input_schema: &Schema, + return_type: DataType, ) -> Result> { - let input_exprs_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), fun.fun(), input_phy_exprs.to_vec(), - fun.return_type_from_exprs(&input_phy_exprs, input_schema)?, + return_type, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), ))) From f195fbaa2ef7ddbd49ce2e3b599fddbb299a4be5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 16:33:34 -0500 Subject: [PATCH 09/17] Apply suggestions from code review Co-authored-by: Alex Huang --- datafusion/expr/src/udf.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 361285ad9e47..3f7f6a76bb24 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -149,6 +149,7 @@ impl ScalarUDF { /// The datatype this function returns given the input argument input types. /// This function is used when the input arguments are [`Expr`]s. /// + /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. pub fn return_type_from_exprs( &self, From 21d495f73b79851ae984099eaea142c78f69a66e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 17:24:28 -0500 Subject: [PATCH 10/17] Fix tests + clippy --- datafusion/expr/src/expr_schema.rs | 61 +++++++++---------- .../optimizer/src/analyzer/type_coercion.rs | 16 ++--- datafusion/physical-expr/src/udf.rs | 3 +- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bd5f050ddedd..dcaa4760c5a1 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -136,7 +136,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - let t = fun.return_type_from_exprs(&args, schema)?; + let t = fun.return_type_from_exprs(args, schema)?; Ok(t) } ScalarFunctionDefinition::Name(_) => { @@ -458,21 +458,23 @@ mod tests { use super::*; use crate::{col, lit}; use arrow::datatypes::{DataType, Fields}; - use datafusion_common::{Column, ScalarValue, TableReference}; + use datafusion_common::{ScalarValue, TableReference}; macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ let expr = lit(ScalarValue::Null).$EXPR_TYPE(); - assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); + assert!(!expr.nullable(&MockExprSchema::new().into_schema()).unwrap()); }}; } #[test] fn expr_schema_nullability() { let expr = col("foo").eq(lit(1)); - assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); + let mock = MockExprSchema::new(); + + assert!(!expr.nullable(&mock.clone().into_schema()).unwrap()); assert!(expr - .nullable(&MockExprSchema::new().with_nullable(true)) + .nullable(&mock.with_nullable(true).into_schema()) .unwrap()); test_is_expr_nullable!(is_null); @@ -491,6 +493,7 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Int32) .with_nullable(nullable) + .into_schema() }; let expr = col("foo").between(lit(1), lit(2)); @@ -515,14 +518,16 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Int32) .with_nullable(nullable) + .into_schema() }; let expr = col("foo").in_list(vec![lit(1); 5], false); assert!(!expr.nullable(&get_schema(false)).unwrap()); assert!(expr.nullable(&get_schema(true)).unwrap()); // Testing nullable() returns an error. + assert!(expr - .nullable(&get_schema(false).with_error_on_nullable(true)) + .nullable(&MockExprSchema::new().with_name("blarg").into_schema()) .is_err()); let null = lit(ScalarValue::Int32(None)); @@ -540,6 +545,7 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Utf8) .with_nullable(nullable) + .into_schema() }; let expr = col("foo").like(lit("bar")); @@ -555,7 +561,7 @@ mod tests { let expr = col("foo"); assert_eq!( DataType::Utf8, - expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8)) + expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8).into_schema()) .unwrap() ); } @@ -567,7 +573,8 @@ mod tests { let expr = col("foo"); let schema = MockExprSchema::new() .with_data_type(DataType::Int32) - .with_metadata(meta.clone()); + .with_metadata(meta.clone()) + .into_schema(); // col and alias should be metadata-preserving assert_eq!(meta, expr.metadata(&schema).unwrap()); @@ -586,7 +593,7 @@ mod tests { let schema = DFSchema::new_with_metadata( vec![DFField::new_unqualified("foo", DataType::Int32, true) .with_metadata(meta.clone())], - HashMap::new(), + meta.clone(), ) .unwrap(); @@ -615,24 +622,29 @@ mod tests { assert!(expr.nullable(&schema).unwrap()); } - #[derive(Debug)] + #[derive(Debug, Clone)] struct MockExprSchema { + name: String, nullable: bool, data_type: DataType, - error_on_nullable: bool, metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { + name: "foo".to_string(), nullable: false, data_type: DataType::Null, - error_on_nullable: false, metadata: HashMap::new(), } } + fn with_name(mut self, name: impl Into) -> Self { + self.name = name.into(); + self + } + fn with_nullable(mut self, nullable: bool) -> Self { self.nullable = nullable; self @@ -643,32 +655,19 @@ mod tests { self } - fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self { - self.error_on_nullable = error_on_nullable; - self - } - fn with_metadata(mut self, metadata: HashMap) -> Self { self.metadata = metadata; self } - } - impl ExprSchema for MockExprSchema { - fn nullable(&self, _col: &Column) -> Result { - if self.error_on_nullable { - internal_err!("nullable error") - } else { - Ok(self.nullable) - } - } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } + /// Create a new schema with a single column + fn into_schema(self) -> DFSchema { + let Self {name, nullable, data_type, metadata} = self; - fn metadata(&self, _col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) + let field = DFField::new_unqualified(&name, data_type, nullable) + .with_metadata(metadata.clone()); + DFSchema::new_with_metadata(vec![field], metadata).unwrap() } } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 662e0fc7c258..fba77047dd74 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -681,17 +681,17 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let case_type = case .expr .as_ref() - .map(|expr| expr.get_type(&schema)) + .map(|expr| expr.get_type(schema)) .transpose()?; let then_types = case .when_then_expr .iter() - .map(|(_when, then)| then.get_type(&schema)) + .map(|(_when, then)| then.get_type(schema)) .collect::>>()?; let else_type = case .else_expr .as_ref() - .map(|expr| expr.get_type(&schema)) + .map(|expr| expr.get_type(schema)) .transpose()?; // find common coercible types @@ -701,7 +701,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let when_types = case .when_then_expr .iter() - .map(|(when, _then)| when.get_type(&schema)) + .map(|(when, _then)| when.get_type(schema)) .collect::>>()?; let coerced_type = get_coerce_type_for_case_expression(&when_types, Some(case_type)); @@ -727,7 +727,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let case_expr = case .expr .zip(case_when_coerce_type.as_ref()) - .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, &schema)) + .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema)) .transpose()? .map(Box::new); let when_then = case @@ -735,7 +735,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { .into_iter() .map(|(when, then)| { let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); - let when = when.cast_to(when_type, &schema).map_err(|e| { + let when = when.cast_to(when_type, schema).map_err(|e| { DataFusionError::Context( format!( "WHEN expressions in CASE couldn't be \ @@ -744,13 +744,13 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { Box::new(e), ) })?; - let then = then.cast_to(&then_else_coerce_type, &schema)?; + let then = then.cast_to(&then_else_coerce_type, schema)?; Ok((Box::new(when), Box::new(then))) }) .collect::>>()?; let else_expr = case .else_expr - .map(|expr| expr.cast_to(&then_else_coerce_type, &schema)) + .map(|expr| expr.cast_to(&then_else_coerce_type, schema)) .transpose()? .map(Box::new); diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index 05d9f99c422d..d9c7c9e5c2a6 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -42,7 +42,6 @@ pub fn create_physical_expr( #[cfg(test)] mod tests { - use arrow::datatypes::Schema; use arrow_schema::DataType; use datafusion_common::Result; use datafusion_expr::{ @@ -98,7 +97,7 @@ mod tests { // create and register the udf let udf = ScalarUDF::from(TestScalarUDF::new()); - let p_expr = create_physical_expr(&udf, &[], &Schema::empty())?; + let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?; assert_eq!( p_expr From b2e84578aefb166097de751c37b891f4b2f3957f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 17:43:39 -0500 Subject: [PATCH 11/17] rework types to use dyn trait --- .../user_defined_scalar_functions.rs | 7 +- datafusion/expr/src/expr_schema.rs | 92 +++++++++---------- datafusion/expr/src/udf.rs | 6 +- 3 files changed, 51 insertions(+), 54 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index f3bd085aab16..a215f0dee73d 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,10 +22,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; -use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, - plan_err, DFSchema, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, ExprSchema}; use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, @@ -534,7 +531,7 @@ impl ScalarUDFImpl for TakeUDF { fn return_type_from_exprs( &self, arg_exprs: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { if arg_exprs.len() != 3 { return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index dcaa4760c5a1..d6032809cfda 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,31 +28,31 @@ use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DataFusionError, ExprSchema, Result, }; use std::collections::HashMap; use std::sync::Arc; /// trait to allow expr to typable with respect to a schema -pub trait ExprSchemable { +pub trait ExprSchemable { /// given a schema, return the type of the expr - fn get_type(&self, schema: &S) -> Result; + fn get_type(&self, schema: &dyn ExprSchema) -> Result; /// given a schema, return the nullability of the expr - fn nullable(&self, input_schema: &S) -> Result; + fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; /// given a schema, return the expr's optional metadata - fn metadata(&self, schema: &S) -> Result>; + fn metadata(&self, schema: &dyn ExprSchema) -> Result>; /// convert to a field with respect to a schema - fn to_field(&self, input_schema: &DFSchema) -> Result; + fn to_field(&self, input_schema: &dyn ExprSchema) -> Result; /// cast to a type with respect to a schema - fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; + fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; } -impl ExprSchemable for Expr { +impl ExprSchemable for Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] /// @@ -90,7 +90,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - fn get_type(&self, schema: &DFSchema) -> Result { + fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { @@ -136,8 +136,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - let t = fun.return_type_from_exprs(args, schema)?; - Ok(t) + Ok(fun.return_type_from_exprs(args, schema)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -221,7 +220,7 @@ impl ExprSchemable for Expr { /// This function errors when it is not possible to compute its /// nullability. This happens when the expression refers to a /// column that does not exist in the schema. - fn nullable(&self, input_schema: &DFSchema) -> Result { + fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) @@ -328,9 +327,9 @@ impl ExprSchemable for Expr { } } - fn metadata(&self, schema: &DFSchema) -> Result> { + fn metadata(&self, schema: &dyn ExprSchema) -> Result> { match self { - Expr::Column(_) => Ok(schema.metadata().clone()), + Expr::Column(c) => Ok(schema.metadata(c)?.clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), _ => Ok(HashMap::new()), } @@ -340,7 +339,7 @@ impl ExprSchemable for Expr { /// /// So for example, a projected expression `col(c1) + col(c2)` is /// placed in an output field **named** col("c1 + c2") - fn to_field(&self, input_schema: &DFSchema) -> Result { + fn to_field(&self, input_schema: &dyn ExprSchema) -> Result { match self { Expr::Column(c) => Ok(DFField::new( c.relation.clone(), @@ -371,7 +370,7 @@ impl ExprSchemable for Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. - fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result { + fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { return Ok(self); @@ -398,7 +397,7 @@ impl ExprSchemable for Expr { fn field_for_index( expr: &Expr, field: &GetFieldAccess, - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { let expr_dt = expr.get_type(schema)?; match field { @@ -458,23 +457,21 @@ mod tests { use super::*; use crate::{col, lit}; use arrow::datatypes::{DataType, Fields}; - use datafusion_common::{ScalarValue, TableReference}; + use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ let expr = lit(ScalarValue::Null).$EXPR_TYPE(); - assert!(!expr.nullable(&MockExprSchema::new().into_schema()).unwrap()); + assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); }}; } #[test] fn expr_schema_nullability() { let expr = col("foo").eq(lit(1)); - let mock = MockExprSchema::new(); - - assert!(!expr.nullable(&mock.clone().into_schema()).unwrap()); + assert!(!expr.nullable(&MockExprSchema::new()).unwrap()); assert!(expr - .nullable(&mock.with_nullable(true).into_schema()) + .nullable(&MockExprSchema::new().with_nullable(true)) .unwrap()); test_is_expr_nullable!(is_null); @@ -493,7 +490,6 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Int32) .with_nullable(nullable) - .into_schema() }; let expr = col("foo").between(lit(1), lit(2)); @@ -518,16 +514,14 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Int32) .with_nullable(nullable) - .into_schema() }; let expr = col("foo").in_list(vec![lit(1); 5], false); assert!(!expr.nullable(&get_schema(false)).unwrap()); assert!(expr.nullable(&get_schema(true)).unwrap()); // Testing nullable() returns an error. - assert!(expr - .nullable(&MockExprSchema::new().with_name("blarg").into_schema()) + .nullable(&get_schema(false).with_error_on_nullable(true)) .is_err()); let null = lit(ScalarValue::Int32(None)); @@ -545,7 +539,6 @@ mod tests { MockExprSchema::new() .with_data_type(DataType::Utf8) .with_nullable(nullable) - .into_schema() }; let expr = col("foo").like(lit("bar")); @@ -561,7 +554,7 @@ mod tests { let expr = col("foo"); assert_eq!( DataType::Utf8, - expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8).into_schema()) + expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8)) .unwrap() ); } @@ -573,8 +566,7 @@ mod tests { let expr = col("foo"); let schema = MockExprSchema::new() .with_data_type(DataType::Int32) - .with_metadata(meta.clone()) - .into_schema(); + .with_metadata(meta.clone()); // col and alias should be metadata-preserving assert_eq!(meta, expr.metadata(&schema).unwrap()); @@ -593,7 +585,7 @@ mod tests { let schema = DFSchema::new_with_metadata( vec![DFField::new_unqualified("foo", DataType::Int32, true) .with_metadata(meta.clone())], - meta.clone(), + HashMap::new(), ) .unwrap(); @@ -622,29 +614,24 @@ mod tests { assert!(expr.nullable(&schema).unwrap()); } - #[derive(Debug, Clone)] + #[derive(Debug)] struct MockExprSchema { - name: String, nullable: bool, data_type: DataType, + error_on_nullable: bool, metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - name: "foo".to_string(), nullable: false, data_type: DataType::Null, + error_on_nullable: false, metadata: HashMap::new(), } } - fn with_name(mut self, name: impl Into) -> Self { - self.name = name.into(); - self - } - fn with_nullable(mut self, nullable: bool) -> Self { self.nullable = nullable; self @@ -655,19 +642,32 @@ mod tests { self } + fn with_error_on_nullable(mut self, error_on_nullable: bool) -> Self { + self.error_on_nullable = error_on_nullable; + self + } + fn with_metadata(mut self, metadata: HashMap) -> Self { self.metadata = metadata; self } + } + impl ExprSchema for MockExprSchema { + fn nullable(&self, _col: &Column) -> Result { + if self.error_on_nullable { + internal_err!("nullable error") + } else { + Ok(self.nullable) + } + } - /// Create a new schema with a single column - fn into_schema(self) -> DFSchema { - let Self {name, nullable, data_type, metadata} = self; + fn data_type(&self, _col: &Column) -> Result<&DataType> { + Ok(&self.data_type) + } - let field = DFField::new_unqualified(&name, data_type, nullable) - .with_metadata(metadata.clone()); - DFSchema::new_with_metadata(vec![field], metadata).unwrap() + fn metadata(&self, _col: &Column) -> Result<&HashMap> { + Ok(&self.metadata) } } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3f7f6a76bb24..b5613f7c2b2c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::{ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -154,7 +154,7 @@ impl ScalarUDF { pub fn return_type_from_exprs( &self, args: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { // If the implementation provides a return_type_from_exprs, use it self.inner.return_type_from_exprs(args, schema) @@ -285,7 +285,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn return_type_from_exprs( &self, args: &[Expr], - schema: &DFSchema, + schema: &dyn ExprSchema, ) -> Result { let arg_types = args .iter() From 4efb39560e4a452ab9c42e4f1390beb7dd2f63de Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 19:17:18 -0500 Subject: [PATCH 12/17] fmt --- .../core/tests/user_defined/user_defined_scalar_functions.rs | 5 ++++- datafusion/expr/src/expr_schema.rs | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a215f0dee73d..17edf6fbd428 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,7 +22,10 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; -use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, ExprSchema}; +use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, + plan_err, DataFusionError, ExprSchema, Result, ScalarValue, +}; use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d6032809cfda..ce97295b58a7 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,8 +28,8 @@ use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFField, - DataFusionError, ExprSchema, Result, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DataFusionError, + ExprSchema, Result, }; use std::collections::HashMap; use std::sync::Arc; From a9546eebf94d4ced263220bc35d88de3fa290397 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 12 Feb 2024 19:56:33 -0500 Subject: [PATCH 13/17] docs --- datafusion/expr/src/expr_schema.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ce97295b58a7..491b4a852261 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -56,7 +56,9 @@ impl ExprSchemable for Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] /// - /// Note: [DFSchema] implements [ExprSchema]. + /// Note: [`DFSchema`] implements [ExprSchema]. + /// + /// [`DFSchema`]: datafusion_common::DFSchema /// /// # Examples /// @@ -213,7 +215,9 @@ impl ExprSchemable for Expr { /// Returns the nullability of the expression based on [ExprSchema]. /// - /// Note: [DFSchema] implements [ExprSchema]. + /// Note: [`DFSchema`] implements [ExprSchema]. + /// + /// [`DFSchema`]: datafusion_common::DFSchema /// /// # Errors /// From 93b72ee98a7c6e2598421e0d9c55b4e965bded44 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Feb 2024 07:10:18 -0500 Subject: [PATCH 14/17] Apply suggestions from code review Co-authored-by: Jeffrey Vo --- .../tests/user_defined/user_defined_scalar_functions.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 17edf6fbd428..9812789740f7 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -526,7 +526,7 @@ impl ScalarUDFImpl for TakeUDF { not_impl_err!("Not called because the return_type_from_exprs is implemented") } - /// Thus function returns the type of the first or second argument based on + /// This function returns the type of the first or second argument based on /// the third argument: /// /// 1. If the third argument is '0', return the type of the first argument @@ -558,7 +558,7 @@ impl ScalarUDFImpl for TakeUDF { arg_exprs.get(take_idx).unwrap().get_type(schema) } - // The actual implementation rethr + // The actual implementation fn invoke(&self, args: &[ColumnarValue]) -> Result { let take_idx = match &args[2] { ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, @@ -595,8 +595,8 @@ async fn verify_udf_return_type() -> Result<()> { let schema = df.schema(); // The output schema should be - // * type of column smallint_col (float64) - // * type of column double_col (float32) + // * type of column smallint_col (int32) + // * type of column double_col (float64) assert_eq!(schema.field(0).data_type(), &DataType::Int32); assert_eq!(schema.field(1).data_type(), &DataType::Float64); From 653577f9a52dc464e2ce3000e5b6a61f8c53fbe5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Feb 2024 07:20:37 -0500 Subject: [PATCH 15/17] Add docs explaining what happens when both `return_type` and `return_type_from_exprs` are called --- datafusion/expr/src/udf.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b5613f7c2b2c..0c8df3ba75b7 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{ExprSchema, Result}; +use datafusion_common::{DataFusionError, ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -254,7 +254,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn signature(&self) -> &Signature; /// What [`DataType`] will be returned by this function, given the types of - /// the arguments + /// the arguments. + /// + /// # Notes + /// + /// If you provide an implementation for [`Self::return_type_from_exprs`], + /// DataFusion will not call `return_type` (this function). In this case it + /// is recommended to return [`DataFusionError::Internal`]. + /// + /// [`DataFusionError::NotImplemented`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; /// What [`DataType`] will be returned by this function, given the From 7993af8bc845b3e4aeb90436a4d8f2351bf731d6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Feb 2024 08:44:45 -0500 Subject: [PATCH 16/17] clippy --- datafusion/expr/src/udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 0c8df3ba75b7..9146ac656185 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -23,7 +23,7 @@ use crate::{ ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{DataFusionError, ExprSchema, Result}; +use datafusion_common::{ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; From 2121770676c0a6c7fc1b161494237c30f83e4df2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 15 Feb 2024 09:04:42 -0500 Subject: [PATCH 17/17] fix doc -- comedy of errors --- datafusion/expr/src/udf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9146ac656185..5b5d92a628c2 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -262,7 +262,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// DataFusion will not call `return_type` (this function). In this case it /// is recommended to return [`DataFusionError::Internal`]. /// - /// [`DataFusionError::NotImplemented`]: datafusion_common::DataFusionError::Internal + /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; /// What [`DataType`] will be returned by this function, given the