Skip to content

Commit

Permalink
address comments from dictionary encoding (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Sep 7, 2024
1 parent 33a096e commit 4a33b9f
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 25 deletions.
19 changes: 12 additions & 7 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,15 @@ use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};

/// check the args of a function return an error, return the return type
pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
/// General implementation of `ScalarUDFImpl::return_type`.
///
/// # Arguments
///
/// * `args` - The arguments to the function
/// * `fn_name` - The name of the function
/// * `value_type` - The general return type of the function, might be wrapped in a dictionary depending
/// on the first argument
pub fn scalar_udf_return_type(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
let Some(first) = args.first() else {
return plan_err!("The '{fn_name}' function requires one or more arguments.");
};
Expand All @@ -20,7 +27,7 @@ pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> Dat
// if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) {
return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}.");
}
args[1..].iter().enumerate().try_for_each(|(index, arg)| {
args.iter().skip(1).enumerate().try_for_each(|(index, arg)| {
if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() {
Ok(())
} else {
Expand All @@ -41,10 +48,8 @@ fn is_str(d: &DataType) -> bool {
}

fn is_int(d: &DataType) -> bool {
matches!(
d,
DataType::UInt64 | DataType::Int64 | DataType::UInt32 | DataType::Int32
)
// TODO we should support more types of int, but that's a longer task
matches!(d, DataType::UInt64 | DataType::Int64)
}

fn dict_key_type(d: &DataType) -> Option<DataType> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_as_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Utf8)
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use datafusion::common::arrow::array::{ArrayRef, BooleanArray};
use datafusion::common::{plan_err, Result, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::common::{check_args, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains {
if arg_types.len() < 2 {
plan_err!("The 'json_contains' function requires two or more arguments.")
} else {
check_args(arg_types, self.name(), DataType::Boolean)
scalar_udf_return_type(arg_types, self.name(), DataType::Boolean)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use datafusion::common::Result as DataFusionResult;
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::{Jiter, NumberAny, NumberInt, Peek};

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;
use crate::common_union::{JsonUnion, JsonUnionField};

Expand Down Expand Up @@ -50,7 +50,7 @@ impl ScalarUDFImpl for JsonGet {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), JsonUnion::data_type())
scalar_udf_return_type(arg_types, self.name(), JsonUnion::data_type())
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_get_bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Boolean)
scalar_udf_return_type(arg_types, self.name(), DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_get_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::{NumberAny, Peek};

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Float64)
scalar_udf_return_type(arg_types, self.name(), DataType::Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_get_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::{NumberInt, Peek};

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Int64)
scalar_udf_return_type(arg_types, self.name(), DataType::Int64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_get_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Utf8)
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_get_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::Utf8)
scalar_udf_return_type(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
4 changes: 2 additions & 2 deletions src/json_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, scalar_udf_return_type, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonLength {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name(), DataType::UInt64)
scalar_udf_return_type(arg_types, self.name(), DataType::UInt64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down

0 comments on commit 4a33b9f

Please sign in to comment.