From 4a33b9f9d4bb895e5ac98e0dcd512a38b8d08759 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 7 Sep 2024 23:02:18 +0100 Subject: [PATCH] address comments from dictionary encoding (#41) --- src/common.rs | 19 ++++++++++++------- src/json_as_text.rs | 4 ++-- src/json_contains.rs | 4 ++-- src/json_get.rs | 4 ++-- src/json_get_bool.rs | 4 ++-- src/json_get_float.rs | 4 ++-- src/json_get_int.rs | 4 ++-- src/json_get_json.rs | 4 ++-- src/json_get_str.rs | 4 ++-- src/json_length.rs | 4 ++-- 10 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/common.rs b/src/common.rs index d66f431..f9445c2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -10,8 +10,15 @@ use jiter::{Jiter, JiterError, Peek}; use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array}; -/// check the args of a function return an error, return the return type -pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult { +/// General implementation of `ScalarUDFImpl::return_type`. +/// +/// # Arguments +/// +/// * `args` - The arguments to the function +/// * `fn_name` - The name of the function +/// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending +/// on the first argument +pub fn scalar_udf_return_type(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult { let Some(first) = args.first() else { return plan_err!("The '{fn_name}' function requires one or more arguments."); }; @@ -20,7 +27,7 @@ pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> Dat // if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) { return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}."); } - args[1..].iter().enumerate().try_for_each(|(index, arg)| { + args.iter().skip(1).enumerate().try_for_each(|(index, arg)| { if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() { Ok(()) } else { @@ -41,10 +48,8 @@ fn is_str(d: &DataType) -> bool { } fn is_int(d: &DataType) -> bool { - matches!( - d, - DataType::UInt64 | DataType::Int64 | DataType::UInt32 | DataType::Int32 - ) + // TODO we should support more types of int, but that's a longer task + matches!(d, DataType::UInt64 | DataType::Int64) } fn dict_key_type(d: &DataType) -> Option { diff --git a/src/json_as_text.rs b/src/json_as_text.rs index 954c7c6..dc0193a 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Utf8) + scalar_udf_return_type(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_contains.rs b/src/json_contains.rs index 8a2d3bb..d5b6e34 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -6,7 +6,7 @@ use datafusion::common::arrow::array::{ArrayRef, BooleanArray}; use datafusion::common::{plan_err, Result, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::common::{check_args, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains { if arg_types.len() < 2 { plan_err!("The 'json_contains' function requires two or more arguments.") } else { - check_args(arg_types, self.name(), DataType::Boolean) + scalar_udf_return_type(arg_types, self.name(), DataType::Boolean) } } diff --git a/src/json_get.rs b/src/json_get.rs index 1523160..0328893 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -8,7 +8,7 @@ use datafusion::common::Result as DataFusionResult; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::{Jiter, NumberAny, NumberInt, Peek}; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; use crate::common_union::{JsonUnion, JsonUnionField}; @@ -50,7 +50,7 @@ impl ScalarUDFImpl for JsonGet { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), JsonUnion::data_type()) + scalar_udf_return_type(arg_types, self.name(), JsonUnion::data_type()) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 8161786..dae742b 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Boolean) + scalar_udf_return_type(arg_types, self.name(), DataType::Boolean) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 469cbc8..31b92d2 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberAny, Peek}; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Float64) + scalar_udf_return_type(arg_types, self.name(), DataType::Float64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index 1cf3ce9..53d91c4 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::{NumberInt, Peek}; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Int64) + scalar_udf_return_type(arg_types, self.name(), DataType::Int64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index 8871a0d..f941d8c 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -6,7 +6,7 @@ use datafusion::arrow::datatypes::DataType; use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Utf8) + scalar_udf_return_type(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index 6a810bf..2f59791 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::Utf8) + scalar_udf_return_type(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_length.rs b/src/json_length.rs index 6239190..54e0b3d 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use jiter::Peek; -use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath}; use crate::common_macros::make_udf_function; make_udf_function!( @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonLength { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name(), DataType::UInt64) + scalar_udf_return_type(arg_types, self.name(), DataType::UInt64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult {