From 33a096efe0ef85c2a42a844e86600ad9426a7915 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 7 Sep 2024 22:38:04 +0100 Subject: [PATCH] Support dict encoded arrays and values (#39) --- src/common.rs | 187 ++++++++++++++++++++++++++++++------------ src/json_as_text.rs | 2 +- src/json_contains.rs | 2 +- src/json_get.rs | 2 +- src/json_get_bool.rs | 2 +- src/json_get_float.rs | 2 +- src/json_get_int.rs | 2 +- src/json_get_json.rs | 2 +- src/json_get_str.rs | 2 +- src/json_length.rs | 2 +- tests/main.rs | 69 ++++++++++++++++ tests/utils/mod.rs | 69 +++++++++++++++- 12 files changed, 278 insertions(+), 65 deletions(-) diff --git a/src/common.rs b/src/common.rs index e3d6d7c..d66f431 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,6 +1,8 @@ use std::str::Utf8Error; -use datafusion::arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, UInt64Array}; +use datafusion::arrow::array::{ + Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array, +}; use datafusion::arrow::datatypes::DataType; use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; use datafusion::logical_expr::ColumnarValue; @@ -8,21 +10,50 @@ use jiter::{Jiter, JiterError, Peek}; use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array}; -pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> { +/// check the args of a function return an error, return the return type +pub fn check_args(args: &[DataType], fn_name: &str, value_type: DataType) -> DataFusionResult { let Some(first) = args.first() else { return plan_err!("The '{fn_name}' function requires one or more arguments."); }; - if !(matches!(first, DataType::Utf8 | DataType::LargeUtf8) || is_json_union(first)) { + let first_dict_key_type = dict_key_type(first); + if !(is_str(first) || is_json_union(first) || first_dict_key_type.is_some()) { // if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) { return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}."); } - args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg { - DataType::Utf8 | DataType::LargeUtf8 | DataType::UInt64 | DataType::Int64 => Ok(()), - t => plan_err!( - "Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {t:?}.", - index + 2 - ), - }) + args[1..].iter().enumerate().try_for_each(|(index, arg)| { + if is_str(arg) || is_int(arg) || dict_key_type(arg).is_some() { + Ok(()) + } else { + plan_err!( + "Unexpected argument type to '{fn_name}' at position {}, expected string or int, got {arg:?}.", + index + 2 + ) + } + })?; + match first_dict_key_type { + Some(t) => Ok(DataType::Dictionary(Box::new(t), Box::new(value_type))), + None => Ok(value_type), + } +} + +fn is_str(d: &DataType) -> bool { + matches!(d, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View) +} + +fn is_int(d: &DataType) -> bool { + matches!( + d, + DataType::UInt64 | DataType::Int64 | DataType::UInt32 | DataType::Int32 + ) +} + +fn dict_key_type(d: &DataType) -> Option { + if let DataType::Dictionary(key, value) = d { + if is_str(value) || is_json_union(value) { + return Some(*key.clone()); + } + } + None } #[derive(Debug)] @@ -73,63 +104,77 @@ pub fn invoke> + 'static, I>( }; match first_arg { ColumnarValue::Array(json_array) => { - let result_collect = match args.get(1) { + let array = match args.get(1) { Some(ColumnarValue::Array(a)) => { if args.len() > 2 { // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23 exec_err!("More than 1 path element is not supported when querying JSON using an array.") - } else if let Some(str_path_array) = a.as_any().downcast_ref::() { - let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); - zip_apply(json_array, paths, jiter_find, true) - } else if let Some(str_path_array) = a.as_any().downcast_ref::() { - let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); - zip_apply(json_array, paths, jiter_find, true) - } else if let Some(int_path_array) = a.as_any().downcast_ref::() { - let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); - zip_apply(json_array, paths, jiter_find, false) - } else if let Some(int_path_array) = a.as_any().downcast_ref::() { - let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); - zip_apply(json_array, paths, jiter_find, false) } else { - exec_err!("unexpected second argument type, expected string or int array") + invoke_array(json_array, a, to_array, jiter_find) } } - Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find), - None => scalar_apply(json_array, &[], jiter_find), + Some(ColumnarValue::Scalar(_)) => { + scalar_apply(json_array, &JsonPath::extract_path(args), to_array, jiter_find) + } + None => scalar_apply(json_array, &[], to_array, jiter_find), }; - to_array(result_collect?).map(ColumnarValue::from) - } - ColumnarValue::Scalar(ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s)) => { - let path = JsonPath::extract_path(args); - let v = jiter_find(s.as_ref().map(String::as_str), &path).ok(); - Ok(ColumnarValue::Scalar(to_scalar(v))) - } - ColumnarValue::Scalar(ScalarValue::Union(type_id_value, union_fields, _)) => { - let opt_json = json_from_union_scalar(type_id_value, union_fields); - let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok(); - Ok(ColumnarValue::Scalar(to_scalar(v))) - } - ColumnarValue::Scalar(_) => { - exec_err!("unexpected first argument type, expected string or JSON union") + array.map(ColumnarValue::from) } + ColumnarValue::Scalar(s) => invoke_scalar(s, args, jiter_find, to_scalar), + } +} + +fn invoke_array> + 'static, I>( + json_array: &ArrayRef, + needle_array: &ArrayRef, + to_array: impl Fn(C) -> DataFusionResult, + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, +) -> DataFusionResult { + if let Some(d) = needle_array.as_any_dictionary_opt() { + invoke_array(json_array, d.values(), to_array, jiter_find) + } else if let Some(str_path_array) = needle_array.as_any().downcast_ref::() { + let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); + zip_apply(json_array, paths, to_array, jiter_find, true) + } else if let Some(str_path_array) = needle_array.as_any().downcast_ref::() { + let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); + zip_apply(json_array, paths, to_array, jiter_find, true) + } else if let Some(str_path_array) = needle_array.as_any().downcast_ref::() { + let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); + zip_apply(json_array, paths, to_array, jiter_find, true) + } else if let Some(int_path_array) = needle_array.as_any().downcast_ref::() { + let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); + zip_apply(json_array, paths, to_array, jiter_find, false) + } else if let Some(int_path_array) = needle_array.as_any().downcast_ref::() { + let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); + zip_apply(json_array, paths, to_array, jiter_find, false) + } else { + exec_err!("unexpected second argument type, expected string or int array") } } fn zip_apply<'a, P: Iterator>>, C: FromIterator> + 'static, I>( json_array: &ArrayRef, path_array: P, + to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, object_lookup: bool, -) -> DataFusionResult { - if let Some(string_array) = json_array.as_any().downcast_ref::() { - Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find)) +) -> DataFusionResult { + if let Some(d) = json_array.as_any_dictionary_opt() { + let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?; + return Ok(d.with_values(a).into()); + } + let c = if let Some(string_array) = json_array.as_any().downcast_ref::() { + zip_apply_iter(string_array.iter(), path_array, jiter_find) } else if let Some(large_string_array) = json_array.as_any().downcast_ref::() { - Ok(zip_apply_iter(large_string_array.iter(), path_array, jiter_find)) + zip_apply_iter(large_string_array.iter(), path_array, jiter_find) + } else if let Some(string_view) = json_array.as_any().downcast_ref::() { + zip_apply_iter(string_view.iter(), path_array, jiter_find) } else if let Some(string_array) = nested_json_array(json_array, object_lookup) { - Ok(zip_apply_iter(string_array.iter(), path_array, jiter_find)) + zip_apply_iter(string_array.iter(), path_array, jiter_find) } else { - exec_err!("unexpected json array type {:?}", json_array.data_type()) - } + return exec_err!("unexpected json array type {:?}", json_array.data_type()); + }; + to_array(c) } fn zip_apply_iter<'a, 'j, P: Iterator>>, C: FromIterator> + 'static, I>( @@ -149,20 +194,54 @@ fn zip_apply_iter<'a, 'j, P: Iterator>>, C: FromItera .collect::() } +fn invoke_scalar( + scalar: &ScalarValue, + args: &[ColumnarValue], + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, + to_scalar: impl Fn(Option) -> ScalarValue, +) -> DataFusionResult { + match scalar { + ScalarValue::Dictionary(_, b) => invoke_scalar(b.as_ref(), args, jiter_find, to_scalar), + ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => { + let path = JsonPath::extract_path(args); + let v = jiter_find(s.as_ref().map(String::as_str), &path).ok(); + Ok(ColumnarValue::Scalar(to_scalar(v))) + } + ScalarValue::Union(type_id_value, union_fields, _) => { + let opt_json = json_from_union_scalar(type_id_value, union_fields); + let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok(); + Ok(ColumnarValue::Scalar(to_scalar(v))) + } + _ => { + exec_err!("unexpected first argument type, expected string or JSON union") + } + } +} + fn scalar_apply>, I>( json_array: &ArrayRef, path: &[JsonPath], + to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, -) -> DataFusionResult { - if let Some(string_array) = json_array.as_any().downcast_ref::() { - Ok(scalar_apply_iter(string_array.iter(), path, jiter_find)) +) -> DataFusionResult { + if let Some(d) = json_array.as_any_dictionary_opt() { + let a = scalar_apply(d.values(), path, to_array, jiter_find)?; + return Ok(d.with_values(a).into()); + } + + let c = if let Some(string_array) = json_array.as_any().downcast_ref::() { + scalar_apply_iter(string_array.iter(), path, jiter_find) } else if let Some(large_string_array) = json_array.as_any().downcast_ref::() { - Ok(scalar_apply_iter(large_string_array.iter(), path, jiter_find)) + scalar_apply_iter(large_string_array.iter(), path, jiter_find) + } else if let Some(string_view_array) = json_array.as_any().downcast_ref::() { + scalar_apply_iter(string_view_array.iter(), path, jiter_find) } else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { - Ok(scalar_apply_iter(string_array.iter(), path, jiter_find)) + scalar_apply_iter(string_array.iter(), path, jiter_find) } else { - exec_err!("unexpected json array type {:?}", json_array.data_type()) - } + return exec_err!("unexpected json array type {:?}", json_array.data_type()); + }; + + to_array(c) } fn is_object_lookup(path: &[JsonPath]) -> bool { diff --git a/src/json_as_text.rs b/src/json_as_text.rs index 19bcb56..954c7c6 100644 --- a/src/json_as_text.rs +++ b/src/json_as_text.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonAsText { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Utf8) + check_args(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_contains.rs b/src/json_contains.rs index c9b9f44..8a2d3bb 100644 --- a/src/json_contains.rs +++ b/src/json_contains.rs @@ -48,7 +48,7 @@ impl ScalarUDFImpl for JsonContains { if arg_types.len() < 2 { plan_err!("The 'json_contains' function requires two or more arguments.") } else { - check_args(arg_types, self.name()).map(|()| DataType::Boolean) + check_args(arg_types, self.name(), DataType::Boolean) } } diff --git a/src/json_get.rs b/src/json_get.rs index 84d1593..1523160 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -50,7 +50,7 @@ impl ScalarUDFImpl for JsonGet { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| JsonUnion::data_type()) + check_args(arg_types, self.name(), JsonUnion::data_type()) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_bool.rs b/src/json_get_bool.rs index 511d9b9..8161786 100644 --- a/src/json_get_bool.rs +++ b/src/json_get_bool.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetBool { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Boolean) + check_args(arg_types, self.name(), DataType::Boolean) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_float.rs b/src/json_get_float.rs index 16fc4f9..469cbc8 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetFloat { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Float64) + check_args(arg_types, self.name(), DataType::Float64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_int.rs b/src/json_get_int.rs index e657d76..1cf3ce9 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetInt { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Int64) + check_args(arg_types, self.name(), DataType::Int64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_json.rs b/src/json_get_json.rs index 66c1d3d..8871a0d 100644 --- a/src/json_get_json.rs +++ b/src/json_get_json.rs @@ -45,7 +45,7 @@ impl ScalarUDFImpl for JsonGetJson { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Utf8) + check_args(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_get_str.rs b/src/json_get_str.rs index ddd0e66..6a810bf 100644 --- a/src/json_get_str.rs +++ b/src/json_get_str.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonGetStr { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::Utf8) + check_args(arg_types, self.name(), DataType::Utf8) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/src/json_length.rs b/src/json_length.rs index 1f9ad09..6239190 100644 --- a/src/json_length.rs +++ b/src/json_length.rs @@ -46,7 +46,7 @@ impl ScalarUDFImpl for JsonLength { } fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { - check_args(arg_types, self.name()).map(|()| DataType::UInt64) + check_args(arg_types, self.name(), DataType::UInt64) } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { diff --git a/tests/main.rs b/tests/main.rs index 91df2c2..1d17f8a 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1138,3 +1138,72 @@ async fn test_arrow_cast_numeric() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); } + +#[tokio::test] +async fn test_dict_haystack() { + let sql = "select json_get(json_data, 'foo') v from dicts"; + let expected = [ + "+-----------------------+", + "| v |", + "+-----------------------+", + "| {object={\"bar\": [0]}} |", + "| {null=} |", + "| {null=} |", + "+-----------------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_haystack_needle() { + let sql = "select json_get(json_get(json_data, str_key1), str_key2) v from dicts"; + let expected = [ + "+-------------+", + "| v |", + "+-------------+", + "| {array=[0]} |", + "| {null=} |", + "| {null=} |", + "+-------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_length() { + let sql = "select json_length(json_data) v from dicts"; + let expected = ["+---+", "| v |", "+---+", "| 1 |", "| 1 |", "| 2 |", "+---+"]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_contains() { + let sql = "select json_contains(json_data, str_key2) v from dicts"; + let expected = [ + "+-------+", + "| v |", + "+-------+", + "| false |", + "| false |", + "| true |", + "+-------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_dict_get_int() { + let sql = "select json_get_int(json_data, str_key2) v from dicts"; + let expected = ["+---+", "| v |", "+---+", "| |", "| |", "| 1 |", "+---+"]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index 7dd91c3..f94b156 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,8 +1,10 @@ #![allow(dead_code)] use std::sync::Arc; -use datafusion::arrow::array::{ArrayRef, Int64Array}; -use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::arrow::array::{ + ArrayRef, DictionaryArray, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array, +}; +use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, UInt32Type, UInt8Type}; use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions}; use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch}; use datafusion::common::ParamValues; @@ -109,6 +111,69 @@ async fn create_test_table(large_utf8: bool) -> Result { )?; ctx.register_batch("more_nested", more_nested_batch)?; + let dict_data = [ + (r#" {"foo": {"bar": [0]}} "#, "foo", "bar", 0), + (r#" {"bar": "snap"} "#, "foo", "spam", 0), + (r#" {"spam": 1, "snap": 2} "#, "foo", "spam", 0), + ]; + let dict_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new( + "json_data", + DataType::Dictionary(DataType::UInt32.into(), DataType::Utf8.into()), + false, + ), + Field::new( + "str_key1", + DataType::Dictionary(DataType::UInt8.into(), DataType::LargeUtf8.into()), + false, + ), + Field::new( + "str_key2", + DataType::Dictionary(DataType::UInt8.into(), DataType::Utf8View.into()), + false, + ), + Field::new( + "int_key", + DataType::Dictionary(DataType::Int64.into(), DataType::UInt64.into()), + false, + ), + ])), + vec![ + Arc::new(DictionaryArray::::new( + UInt32Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u32)), + Arc::new(StringArray::from( + dict_data.iter().map(|(json, _, _, _)| *json).collect::>(), + )), + )), + Arc::new(DictionaryArray::::new( + UInt8Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u8)), + Arc::new(LargeStringArray::from( + dict_data + .iter() + .map(|(_, str_key1, _, _)| *str_key1) + .collect::>(), + )), + )), + Arc::new(DictionaryArray::::new( + UInt8Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as u8)), + Arc::new(StringViewArray::from( + dict_data + .iter() + .map(|(_, _, str_key2, _)| *str_key2) + .collect::>(), + )), + )), + Arc::new(DictionaryArray::::new( + Int64Array::from_iter_values(dict_data.iter().enumerate().map(|(id, _)| id as i64)), + Arc::new(UInt64Array::from_iter_values( + dict_data.iter().map(|(_, _, _, int_key)| *int_key as u64), + )), + )), + ], + )?; + ctx.register_batch("dicts", dict_batch)?; + Ok(ctx) }