Skip to content

Commit

Permalink
Make fields of ScalarUDF non pub
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Nov 3, 2023
1 parent c2e7680 commit c67e620
Show file tree
Hide file tree
Showing 14 changed files with 57 additions and 31 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
}
}
Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
match fun.signature.volatility {
match fun.signature().volatility {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that would require access to the context
Volatility::Stable | Volatility::Volatile => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ impl SessionContext {
self.state
.write()
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
.insert(f.name().to_string(), Arc::new(f));
}

/// Registers an aggregate UDF within this context.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
create_function_physical_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_physical_name(&fun.name, false, args)
create_function_physical_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1178,7 +1178,7 @@ impl fmt::Display for Expr {
fmt_function(f, &func.fun.to_string(), false, &func.args, true)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
fmt_function(f, &fun.name, false, args, true)
fmt_function(f, fun.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down Expand Up @@ -1512,7 +1512,7 @@ fn create_name(e: &Expr) -> Result<String> {
create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => {
create_function_name(&fun.name, false, args)
create_function_name(fun.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl ExprSchemable for Expr {
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
Ok(fun.return_type(&data_types)?)
}
Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
Expand Down
45 changes: 36 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,31 @@
// specific language governing permissions and limitations
// under the License.

//! Udf module contains foundational types that are used to represent UDFs in DataFusion.
//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

/// Logical representation of a UDF.
/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input.
///
/// This struct contains the information DataFusion needs to plan and invoke
/// functions such as name, type signature, return type, and actual implementation.
///
#[derive(Clone)]
pub struct ScalarUDF {
/// name
pub name: String,
/// signature
pub signature: Signature,
/// Return type
pub return_type: ReturnTypeFunction,
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
Expand All @@ -40,7 +48,7 @@ pub struct ScalarUDF {
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
pub fun: ScalarFunctionImplementation,
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUDF {
Expand Down Expand Up @@ -89,4 +97,23 @@ impl ScalarUDF {
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()), args))
}

/// Returns this function's name.
///
/// The returned `&str` borrows from this `ScalarUDF`; callers that need an
/// owned `String` must copy it (for example with `to_string()`).
pub fn name(&self) -> &str {
&self.name
}
/// Returns a reference to this function's [`Signature`].
///
/// The signature describes the types of arguments that are supported and
/// carries the function's `volatility`.
pub fn signature(&self) -> &Signature {
&self.signature
}
/// Computes the return type of this function for the given argument types.
///
/// # Errors
///
/// Returns any error produced by the stored return-type function.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
    // The stored return-type function hands the `DataType` back behind an
    // `Arc`; copy it out so callers receive a plain owned value.
    (self.return_type)(args).map(|shared| shared.as_ref().clone())
}
/// Returns a reference to the actual implementation of this function.
///
/// Only a borrow is handed out; callers that need their own handle clone the
/// returned [`ScalarFunctionImplementation`].
pub fn fun(&self) -> &ScalarFunctionImplementation {
&self.fun
}
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
fun.signature(),
)?;
Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ impl<'a> ConstEvaluator<'a> {
Self::volatility_ok(fun.volatility())
}
Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => {
Self::volatility_ok(fun.signature.volatility)
Self::volatility_ok(fun.signature().volatility)
}
Expr::Literal(_)
| Expr::BinaryExpr { .. }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ pub fn create_physical_expr(
&format!("{fun}"),
fun_expr,
input_phy_exprs.to_vec(),
&data_type,
data_type,
monotonicity,
)))
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ impl ScalarFunctionExpr {
name: &str,
fun: ScalarFunctionImplementation,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: &DataType,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type: return_type.clone(),
return_type,
monotonicity,
}
}
Expand Down Expand Up @@ -165,7 +165,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
&self.name,
self.fun.clone(),
children,
self.return_type(),
self.return_type().clone(),
self.monotonicity.clone(),
)))
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ pub fn create_physical_expr(
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(ScalarFunctionExpr::new(
&fun.name,
fun.fun.clone(),
fun.name(),
fun.fun().clone(),
input_phy_exprs.to_vec(),
(fun.return_type)(&input_exprs_types)?.as_ref(),
fun.return_type(&input_exprs_types)?,
None,
)))
}
2 changes: 1 addition & 1 deletion datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
}
Expr::ScalarUDF(ScalarUDF { fun, args }) => Self {
expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
fun_name: fun.name.clone(),
fun_name: fun.name().to_string(),
args: args
.iter()
.map(|expr| expr.try_into())
Expand Down
7 changes: 3 additions & 4 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
//! Serde code to convert from protocol buffers to Rust data structures.

use std::convert::{TryFrom, TryInto};
use std::ops::Deref;
use std::sync::Arc;

use arrow::compute::SortOptions;
Expand Down Expand Up @@ -308,12 +307,12 @@ pub fn parse_physical_expr(
&e.name,
fun_expr,
args,
&convert_required!(e.return_type)?,
convert_required!(e.return_type)?,
None,
))
}
ExprType::ScalarUdf(e) => {
let scalar_fun = registry.udf(e.name.as_str())?.deref().clone().fun;
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();

let args = e
.args
Expand All @@ -325,7 +324,7 @@ pub fn parse_physical_expr(
e.name.as_str(),
scalar_fun,
args,
&convert_required!(e.return_type)?,
convert_required!(e.return_type)?,
None,
))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
"acos",
fun_expr,
vec![col("a", &schema)?],
&DataType::Int64,
DataType::Int64,
None,
);

Expand Down Expand Up @@ -549,7 +549,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
"dummy",
scalar_fn,
vec![col("a", &schema)?],
&DataType::Int64,
DataType::Int64,
None,
);

Expand Down

0 comments on commit c67e620

Please sign in to comment.