Skip to content

Commit

Permalink
Support dict encoded arrays and values (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Sep 7, 2024
1 parent e50cd3f commit 33a096e
Show file tree
Hide file tree
Showing 12 changed files with 278 additions and 65 deletions.
187 changes: 133 additions & 54 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,59 @@
use std::str::Utf8Error;

use datafusion::arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, UInt64Array};
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};

pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> {
/// Validate the argument types passed to the function named `fn_name` and compute its return type.
///
/// The first argument must be a string type, a JSON union, or a dictionary whose values are one
/// of those. Every following argument must be a string, an integer, or such a dictionary.
/// Returns a plan error naming the offending (1-based) position otherwise.
///
/// On success the return type is `value_type`, wrapped in a `Dictionary` that reuses the first
/// argument's key type when the first argument is dictionary-encoded.
pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult<DataType> {
let Some(first) = args.first() else {
return plan_err!("The '{fn_name}' function requires one or more arguments.");
};
if !(matches!(first, DataType::Utf8 | DataType::LargeUtf8) || is_json_union(first)) {
let first_dict_key_type = dict_key_type(first);
if !(is_str(first) || is_json_union(first) || first_dict_key_type.is_some()) {
// if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) {
return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}.");
}
args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg {
DataType::Utf8 | DataType::LargeUtf8 | DataType::UInt64 | DataType::Int64 => Ok(()),
t => plan_err!(
"Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {t:?}.",
index + 2
),
})
// `index` counts from 0 over `args[1..]`, so `index + 2` is the 1-based position in `args`.
args[1..].iter().enumerate().try_for_each(|(index, arg)| {
if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() {
Ok(())
} else {
plan_err!(
"Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {arg:?}.",
index + 2
)
}
})?;
// A dictionary-encoded first argument produces a dictionary-encoded result with the same key type.
match first_dict_key_type {
Some(t) => Ok(DataType::Dictionary(Box::new(t), Box::new(value_type))),
None => Ok(value_type),
}
}

/// Returns `true` when `d` is one of the Arrow string types accepted here:
/// `Utf8`, `LargeUtf8` or `Utf8View`.
fn is_str(d: &DataType) -> bool {
    match d {
        DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => true,
        _ => false,
    }
}

/// Returns `true` when `d` is a 32- or 64-bit integer type, signed or unsigned.
fn is_int(d: &DataType) -> bool {
    match d {
        DataType::Int32 | DataType::UInt32 => true,
        DataType::Int64 | DataType::UInt64 => true,
        _ => false,
    }
}

/// If `d` is a dictionary whose values are strings or a JSON union, returns a copy of the
/// dictionary's key type; returns `None` for every other data type.
fn dict_key_type(d: &DataType) -> Option<DataType> {
    match d {
        DataType::Dictionary(key, value) if is_str(value) || is_json_union(value) => {
            Some((**key).clone())
        }
        _ => None,
    }
}

#[derive(Debug)]
Expand Down Expand Up @@ -73,63 +104,77 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
};
match first_arg {
ColumnarValue::Array(json_array) => {
let result_collect = match args.get(1) {
let array = match args.get(1) {
Some(ColumnarValue::Array(a)) => {
if args.len() > 2 {
// TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
} else if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find, true)
} else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find, true)
} else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, jiter_find, false)
} else if let Some(int_path_array) = a.as_any().downcast_ref::<UInt64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, jiter_find, false)
} else {
exec_err!("unexpected second argument type, expected string or int array")
invoke_array(json_array, a, to_array, jiter_find)
}
}
Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
None => scalar_apply(json_array, &[], jiter_find),
Some(ColumnarValue::Scalar(_)) => {
scalar_apply(json_array, &JsonPath::extract_path(args), to_array, jiter_find)
}
None => scalar_apply(json_array, &[], to_array, jiter_find),
};
to_array(result_collect?).map(ColumnarValue::from)
}
ColumnarValue::Scalar(ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s)) => {
let path = JsonPath::extract_path(args);
let v = jiter_find(s.as_ref().map(String::as_str), &path).ok();
Ok(ColumnarValue::Scalar(to_scalar(v)))
}
ColumnarValue::Scalar(ScalarValue::Union(type_id_value, union_fields, _)) => {
let opt_json = json_from_union_scalar(type_id_value, union_fields);
let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok();
Ok(ColumnarValue::Scalar(to_scalar(v)))
}
ColumnarValue::Scalar(_) => {
exec_err!("unexpected first argument type, expected string or JSON union")
array.map(ColumnarValue::from)
}
ColumnarValue::Scalar(s) => invoke_scalar(s, args, jiter_find, to_scalar),
}
}

fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
json_array: &ArrayRef,
needle_array: &ArrayRef,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = needle_array.as_any_dictionary_opt() {
invoke_array(json_array, d.values(), to_array, jiter_find)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<LargeStringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringViewArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true)
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<Int64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, to_array, jiter_find, false)
} else if let Some(int_path_array) = needle_array.as_any().downcast_ref::<UInt64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, to_array, jiter_find, false)
} else {
exec_err!("unexpected second argument type, expected string or int array")
}
}

/// Pair the JSON values in `json_array` with the path elements yielded by `path_array`, run
/// `jiter_find` on each pair (via `zip_apply_iter`) and build the output array with `to_array`.
///
/// `json_array` may be a `StringArray`, `LargeStringArray`, `StringViewArray`, an array accepted
/// by `nested_json_array` (which is also given `object_lookup`), or a dictionary over any of
/// these. Any other array type yields an execution error.
fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
json_array: &ArrayRef,
path_array: P,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
object_lookup: bool,
) -> DataFusionResult<C> {
if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
) -> DataFusionResult<ArrayRef> {
// Dictionary input: process the dictionary's values and re-wrap the result with `with_values`.
// NOTE(review): `path_array` is paired with the dictionary *values* here, not with the logical
// rows; presumably the two only line up when every row maps to its own value — confirm.
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
return Ok(d.with_values(a).into());
}
let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
Ok(zip_apply_iter(large_string_array.iter(), path_array, jiter_find))
zip_apply_iter(large_string_array.iter(), path_array, jiter_find)
} else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() {
zip_apply_iter(string_view.iter(), path_array, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find))
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
exec_err!("unexpected json array type {:?}", json_array.data_type())
}
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
// Convert the collected per-row results into the output array.
to_array(c)
}

fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
Expand All @@ -149,20 +194,54 @@ fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromItera
.collect::<C>()
}

/// Evaluate a JSON lookup when the first argument is a scalar.
///
/// The JSON text is taken from a string scalar (`Utf8`, `Utf8View`, `LargeUtf8`), from a JSON
/// union scalar, or from the value wrapped inside a dictionary scalar. The path is built from
/// `args` with `JsonPath::extract_path`; a failed lookup is passed to `to_scalar` as `None`
/// rather than being reported as an error.
///
/// Any other scalar type yields an execution error.
fn invoke_scalar<I>(
    scalar: &ScalarValue,
    args: &[ColumnarValue],
    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
    to_scalar: impl Fn(Option<I>) -> ScalarValue,
) -> DataFusionResult<ColumnarValue> {
    // Resolve the JSON text first; the lookup itself is identical for every supported input.
    let json: Option<&str> = match scalar {
        // A dictionary scalar just wraps another scalar: unwrap it and start over.
        ScalarValue::Dictionary(_, inner) => {
            return invoke_scalar(inner.as_ref(), args, jiter_find, to_scalar);
        }
        ScalarValue::Utf8(text) | ScalarValue::Utf8View(text) | ScalarValue::LargeUtf8(text) => {
            text.as_deref()
        }
        ScalarValue::Union(type_id_value, union_fields, _) => {
            json_from_union_scalar(type_id_value, union_fields)
        }
        _ => {
            return exec_err!("unexpected first argument type, expected string or JSON union");
        }
    };

    let path = JsonPath::extract_path(args);
    let found = jiter_find(json, &path).ok();
    Ok(ColumnarValue::Scalar(to_scalar(found)))
}

/// Apply one fixed `path` to every JSON value in `json_array` using `jiter_find` (via
/// `scalar_apply_iter`) and build the output array with `to_array`.
///
/// `json_array` may be a `StringArray`, `LargeStringArray`, `StringViewArray`, an array accepted
/// by `nested_json_array`, or a dictionary over any of these. Any other array type yields an
/// execution error.
fn scalar_apply<C: FromIterator<Option<I>>, I>(
json_array: &ArrayRef,
path: &[JsonPath],
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<C> {
if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
Ok(scalar_apply_iter(string_array.iter(), path, jiter_find))
) -> DataFusionResult<ArrayRef> {
// Dictionary input: the path is the same for every row, so the lookup runs on the dictionary's
// values only and the result is re-wrapped with `with_values`.
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = scalar_apply(d.values(), path, to_array, jiter_find)?;
return Ok(d.with_values(a).into());
}

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
Ok(scalar_apply_iter(large_string_array.iter(), path, jiter_find))
scalar_apply_iter(large_string_array.iter(), path, jiter_find)
} else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() {
scalar_apply_iter(string_view_array.iter(), path, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
Ok(scalar_apply_iter(string_array.iter(), path, jiter_find))
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
exec_err!("unexpected json array type {:?}", json_array.data_type())
}
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};

// Convert the collected per-row results into the output array.
to_array(c)
}

fn is_object_lookup(path: &[JsonPath]) -> bool {
Expand Down
2 changes: 1 addition & 1 deletion src/json_as_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Utf8)
check_args(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains {
if arg_types.len() < 2 {
plan_err!("The 'json_contains' function requires two or more arguments.")
} else {
check_args(arg_types, self.name()).map(|()| DataType::Boolean)
check_args(arg_types, self.name(), DataType::Boolean)
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl ScalarUDFImpl for JsonGet {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| JsonUnion::data_type())
check_args(arg_types, self.name(), JsonUnion::data_type())
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Boolean)
check_args(arg_types, self.name(), DataType::Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Float64)
check_args(arg_types, self.name(), DataType::Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Int64)
check_args(arg_types, self.name(), DataType::Int64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Utf8)
check_args(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_get_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::Utf8)
check_args(arg_types, self.name(), DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
2 changes: 1 addition & 1 deletion src/json_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonLength {
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|()| DataType::UInt64)
check_args(arg_types, self.name(), DataType::UInt64)
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
Expand Down
Loading

0 comments on commit 33a096e

Please sign in to comment.