diff --git a/Cargo.toml b/Cargo.toml index 6548081..6ab9905 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ paste = "1" log = "0.4" [dev-dependencies] +datafusion = { version = "44", default-features = false, features = ["nested_expressions"] } codspeed-criterion-compat = "2.6" criterion = "0.5.1" clap = "4" diff --git a/src/common.rs b/src/common.rs index fc5b115..fdba8d2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -95,17 +95,46 @@ impl From for JsonPath<'_> { } } -impl<'s> JsonPath<'s> { - pub fn extract_path(args: &'s [ColumnarValue]) -> Vec { - args[1..] +enum JsonPathArgs<'a> { + Array(&'a ArrayRef), + Scalars(Vec>), +} + +impl<'s> JsonPathArgs<'s> { + fn extract_path(path_args: &'s [ColumnarValue]) -> DataFusionResult { + // If there is a single argument as an array, we know how to handle it + if let Some((ColumnarValue::Array(array), &[])) = path_args.split_first() { + return Ok(Self::Array(array)); + } + + path_args .iter() - .map(|arg| match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s))) => Self::Key(s), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => (*i).into(), - ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => (*i).into(), - _ => Self::None, + .enumerate() + .map(|(pos, arg)| match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s))) => { + Ok(JsonPath::Key(s)) + } + ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Ok((*i).into()), + ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => Ok((*i).into()), + ColumnarValue::Scalar( + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::UInt64(None) + | ScalarValue::Int64(None), + ) => Ok(JsonPath::None), + ColumnarValue::Array(_) => { + // if there was a single arg, which is an array, handled above in the + // split_first case. So this is multiple args of which one is an array + exec_err!("More than 1 path element is not supported when querying JSON using an array.") + } + ColumnarValue::Scalar(arg) => exec_err!( + "Unexpected argument type at position {}, expected string or int, got {arg:?}.", + pos + 1 + ), }) - .collect() + .collect::>() + .map(JsonPathArgs::Scalars) } } @@ -116,154 +145,173 @@ pub fn invoke> + 'static, I>( to_scalar: impl Fn(Option) -> ScalarValue, return_dict: bool, ) -> DataFusionResult { - let Some(first_arg) = args.first() else { - // I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe + let Some((json_arg, path_args)) = args.split_first() else { return exec_err!("expected at least one argument"); }; - match first_arg { - ColumnarValue::Array(json_array) => { - let array = match args.get(1) { - Some(ColumnarValue::Array(a)) => { - if args.len() > 2 { - // TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23 - exec_err!("More than 1 path element is not supported when querying JSON using an array.") - } else { - invoke_array(json_array, a, to_array, jiter_find, return_dict) - } - } - Some(ColumnarValue::Scalar(_)) => scalar_apply( - json_array, - &JsonPath::extract_path(args), - to_array, - jiter_find, - return_dict, - ), - None => scalar_apply(json_array, &[], to_array, jiter_find, return_dict), - }; - array.map(ColumnarValue::from) + + let path = JsonPathArgs::extract_path(path_args)?; + match (json_arg, path) { + (ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => { + invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array) + } + (ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => { + invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array) + } + (ColumnarValue::Scalar(s), JsonPathArgs::Array(path_array)) => { + invoke_scalar_array(s, path_array, jiter_find, to_array) + } + (ColumnarValue::Scalar(s), JsonPathArgs::Scalars(path)) => { + invoke_scalar_scalars(s, &path, jiter_find, to_scalar) } - ColumnarValue::Scalar(s) => invoke_scalar(s, args, jiter_find, to_scalar), } } -fn invoke_array> + 'static, I>( +fn invoke_array_array> + 'static, I>( json_array: &ArrayRef, - needle_array: &ArrayRef, + path_array: &ArrayRef, to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, return_dict: bool, ) -> DataFusionResult { downcast_dictionary_array!( - needle_array => match needle_array.values().data_type() { - DataType::Utf8 => zip_apply(json_array, needle_array.downcast_dict::().unwrap(), to_array, jiter_find, true, return_dict), - DataType::LargeUtf8 => zip_apply(json_array, needle_array.downcast_dict::().unwrap(), to_array, jiter_find, true, return_dict), - DataType::Utf8View => zip_apply(json_array, needle_array.downcast_dict::().unwrap(), to_array, jiter_find, true, return_dict), - DataType::Int64 => zip_apply(json_array, needle_array.downcast_dict::().unwrap(), to_array, jiter_find, false, return_dict), - DataType::UInt64 => zip_apply(json_array, needle_array.downcast_dict::().unwrap(), to_array, jiter_find, false, return_dict), - other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other), - }, - DataType::Utf8 => zip_apply(json_array, needle_array.as_string::(), to_array, jiter_find, true, return_dict), - DataType::LargeUtf8 => zip_apply(json_array, needle_array.as_string::(), to_array, jiter_find, true, return_dict), - DataType::Utf8View => zip_apply(json_array, needle_array.as_string_view(), to_array, jiter_find, true, return_dict), - DataType::Int64 => zip_apply(json_array, needle_array.as_primitive::(), to_array, jiter_find, false, return_dict), - DataType::UInt64 => zip_apply(json_array, needle_array.as_primitive::(), to_array, jiter_find, false, return_dict), - other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other) + json_array => { + let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?; + post_process_dict(json_array, values, return_dict) + } + DataType::Utf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), + DataType::LargeUtf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), + DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find), + other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup_array(path_array.data_type())) { + zip_apply(string_array.iter(), path_array, to_array, jiter_find) + } else { + exec_err!("unexpected json array type {:?}", other) + } ) } -fn zip_apply<'a, P: Into>, C: FromIterator> + 'static, I>( +fn invoke_array_scalars>, I>( json_array: &ArrayRef, - path_array: impl ArrayAccessor, + path: &[JsonPath], to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, - object_lookup: bool, return_dict: bool, ) -> DataFusionResult { + fn inner<'j, C: FromIterator>, I>( + json_iter: impl IntoIterator>, + path: &[JsonPath], + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, + ) -> C { + json_iter + .into_iter() + .map(|opt_json| jiter_find(opt_json, path).ok()) + .collect::() + } + let c = downcast_dictionary_array!( json_array => { - let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup, false)?; + let values = invoke_array_scalars(json_array.values(), path, to_array, jiter_find, false)?; return post_process_dict(json_array, values, return_dict); } - DataType::Utf8 => zip_apply_iter(json_array.as_string::().iter(), path_array, jiter_find), - DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::().iter(), path_array, jiter_find), - DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find), - other => if let Some(string_array) = nested_json_array(json_array, object_lookup) { - zip_apply_iter(string_array.iter(), path_array, jiter_find) + DataType::Utf8 => inner(json_array.as_string::(), path, jiter_find), + DataType::LargeUtf8 => inner(json_array.as_string::(), path, jiter_find), + DataType::Utf8View => inner(json_array.as_string_view(), path, jiter_find), + other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { + inner(string_array, path, jiter_find) } else { return exec_err!("unexpected json array type {:?}", other); } ); - to_array(c) } -#[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references -fn zip_apply_iter<'a, 'j, P: Into>, C: FromIterator> + 'static, I>( - json_iter: impl Iterator>, - path_array: impl ArrayAccessor, +fn invoke_scalar_array> + 'static, I>( + scalar: &ScalarValue, + path_array: &ArrayRef, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, -) -> C { - json_iter - .enumerate() - .map(|(i, opt_json)| { - if path_array.is_null(i) { - None - } else { - let path = path_array.value(i).into(); - jiter_find(opt_json, &[path]).ok() - } - }) - .collect::() + to_array: impl Fn(C) -> DataFusionResult, +) -> DataFusionResult { + let s = extract_json_scalar(scalar)?; + // TODO: possible optimization here if path_array is a dictionary; can apply against the + // dictionary values directly for less work + zip_apply( + std::iter::repeat(s).take(path_array.len()), + path_array, + to_array, + jiter_find, + ) + .map(ColumnarValue::Array) } -fn invoke_scalar( +fn invoke_scalar_scalars( scalar: &ScalarValue, - args: &[ColumnarValue], + path: &[JsonPath], jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, to_scalar: impl Fn(Option) -> ScalarValue, ) -> DataFusionResult { - match scalar { - ScalarValue::Dictionary(_, b) => invoke_scalar(b.as_ref(), args, jiter_find, to_scalar), - ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => { - let path = JsonPath::extract_path(args); - let v = jiter_find(s.as_ref().map(String::as_str), &path).ok(); - Ok(ColumnarValue::Scalar(to_scalar(v))) - } - ScalarValue::Union(type_id_value, union_fields, _) => { - let opt_json = json_from_union_scalar(type_id_value.as_ref(), union_fields); - let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok(); - Ok(ColumnarValue::Scalar(to_scalar(v))) - } - _ => { - exec_err!("unexpected first argument type, expected string or JSON union") - } - } + let s = extract_json_scalar(scalar)?; + let v = jiter_find(s, path).ok(); + Ok(ColumnarValue::Scalar(to_scalar(v))) } -fn scalar_apply>, I>( - json_array: &ArrayRef, - path: &[JsonPath], +fn zip_apply<'a, C: FromIterator> + 'static, I>( + json_array: impl IntoIterator>, + path_array: &ArrayRef, to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, - return_dict: bool, ) -> DataFusionResult { + #[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references + fn inner<'a, 'j, P: Into>, C: FromIterator> + 'static, I>( + json_iter: impl IntoIterator>, + path_array: impl ArrayAccessor, + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, + ) -> C { + json_iter + .into_iter() + .enumerate() + .map(|(i, opt_json)| { + if path_array.is_null(i) { + None + } else { + let path = path_array.value(i).into(); + jiter_find(opt_json, &[path]).ok() + } + }) + .collect::() + } + let c = downcast_dictionary_array!( - json_array => { - let values = scalar_apply(json_array.values(), path, to_array, jiter_find, false)?; - return post_process_dict(json_array, values, return_dict); - } - DataType::Utf8 => scalar_apply_iter(json_array.as_string::().iter(), path, jiter_find), - DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::().iter(), path, jiter_find), - DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find), - other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) { - scalar_apply_iter(string_array.iter(), path, jiter_find) - } else { - return exec_err!("unexpected json array type {:?}", other); - } + path_array => match path_array.values().data_type() { + DataType::Utf8 => inner(json_array, path_array.downcast_dict::().unwrap(), jiter_find), + DataType::LargeUtf8 => inner(json_array, path_array.downcast_dict::().unwrap(), jiter_find), + DataType::Utf8View => inner(json_array, path_array.downcast_dict::().unwrap(), jiter_find), + DataType::Int64 => inner(json_array, path_array.downcast_dict::().unwrap(), jiter_find), + DataType::UInt64 => inner(json_array, path_array.downcast_dict::().unwrap(), jiter_find), + other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other), + }, + DataType::Utf8 => inner(json_array, path_array.as_string::(), jiter_find), + DataType::LargeUtf8 => inner(json_array, path_array.as_string::(), jiter_find), + DataType::Utf8View => inner(json_array, path_array.as_string_view(), jiter_find), + DataType::Int64 => inner(json_array, path_array.as_primitive::(), jiter_find), + DataType::UInt64 => inner(json_array, path_array.as_primitive::(), jiter_find), + other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other) ); + to_array(c) } +fn extract_json_scalar(scalar: &ScalarValue) -> DataFusionResult> { + match scalar { + ScalarValue::Dictionary(_, b) => extract_json_scalar(b.as_ref()), + ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => Ok(s.as_deref()), + ScalarValue::Union(type_id_value, union_fields, _) => { + Ok(json_from_union_scalar(type_id_value.as_ref(), union_fields)) + } + _ => { + exec_err!("unexpected first argument type, expected string or JSON union") + } + } +} + /// Take a dictionary array of JSON data and an array of result values and combine them. fn post_process_dict( dict_array: &DictionaryArray, @@ -295,12 +343,12 @@ fn is_object_lookup(path: &[JsonPath]) -> bool { } } -fn scalar_apply_iter<'j, C: FromIterator>, I>( - json_iter: impl Iterator>, - path: &[JsonPath], - jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, -) -> C { - json_iter.map(|opt_json| jiter_find(opt_json, path).ok()).collect::() +fn is_object_lookup_array(data_type: &DataType) -> bool { + match data_type { + DataType::Dictionary(_, value_type) => is_object_lookup_array(value_type), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => true, + _ => false, + } } pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { diff --git a/tests/main.rs b/tests/main.rs index 567f8ce..54a3f9b 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1596,3 +1596,79 @@ async fn test_json_object_keys_nested() { ]; assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_lookup_literal_column_matrix() { + let sql = r#" +WITH attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT unnest(['a', 'b', 'c']) as attr_name +), json_columns AS ( + SELECT unnest(['{"a": 1}', '{"b": 2}']) as json_column +) +SELECT + attr_name, + json_column, + 'a' = attr_name, + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, 'a') -- column lookup with literal +FROM attr_names, json_columns +"#; + + let expected = [ + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {null=} | {null=} |", + "| b | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", + "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | {null=} |", + "| c | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", + "| c | {\"b\": 2} | false | {null=} | {int=1} | {null=} | {null=} |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_lookup_literal_column_matrix_dictionaries() { + let sql = r#" +WITH attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name +), json_columns AS ( + SELECT arrow_cast(unnest(['{"a": 1}', '{"b": 2}']), 'Dictionary(Int32, Utf8)') as json_column +) +SELECT + attr_name, + json_column, + 'a' = attr_name, + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, 'a') -- column lookup with literal +FROM attr_names, json_columns +"#; + + // NB as compared to the non-dictionary case, we null out the dictionary keys if the return + // value is a dict, which is why we get true nulls instead of {null=} + let expected = [ + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | | |", + "| b | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", + "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | |", + "| c | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", + "| c | {\"b\": 2} | false | {null=} | {int=1} | | |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + ]; + + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +}