diff --git a/src/json_object_keys.rs b/src/json_object_keys.rs new file mode 100644 index 0000000..e89c0bd --- /dev/null +++ b/src/json_object_keys.rs @@ -0,0 +1,127 @@ +use std::any::Any; +use std::sync::Arc; + +use datafusion::arrow::array::{ArrayRef, ListArray, ListBuilder, StringBuilder}; +use datafusion::arrow::datatypes::{DataType, Field}; +use datafusion::common::{Result as DataFusionResult, ScalarValue}; +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use jiter::Peek; + +use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath}; +use crate::common_macros::make_udf_function; + +make_udf_function!( + JsonObjectKeys, + json_object_keys, + json_data path, + r#"Get the keys of a JSON object as an array."# +); + +#[derive(Debug)] +pub(super) struct JsonObjectKeys { + signature: Signature, + aliases: [String; 2], +} + +impl Default for JsonObjectKeys { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_object_keys".to_string(), "json_keys".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonObjectKeys { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + return_type_check( + arg_types, + self.name(), + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + ) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { + invoke::>( + args, + jiter_json_object_keys, + |w| Ok(Arc::new(w.0) as ArrayRef), + keys_to_scalar, + true, + ) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Wrapper for a `ListArray` that allows us to implement `FromIterator>>` as required. +#[derive(Debug)] +struct ListArrayWrapper(ListArray); + +impl FromIterator>> for ListArrayWrapper { + fn from_iter>>>(iter: I) -> Self { + let values_builder = StringBuilder::new(); + let mut builder = ListBuilder::new(values_builder); + for opt_keys in iter { + if let Some(keys) = opt_keys { + for value in keys { + builder.values().append_value(value); + } + builder.append(true); + } else { + builder.append(false); + } + } + Self(builder.finish()) + } +} + +fn keys_to_scalar(opt_keys: Option>) -> ScalarValue { + let values_builder = StringBuilder::new(); + let mut builder = ListBuilder::new(values_builder); + if let Some(keys) = opt_keys { + for value in keys { + builder.values().append_value(value); + } + builder.append(true); + } else { + builder.append(false); + } + let array = builder.finish(); + ScalarValue::List(Arc::new(array)) +} + +fn jiter_json_object_keys(opt_json: Option<&str>, path: &[JsonPath]) -> Result, GetError> { + if let Some((mut jiter, peek)) = jiter_json_find(opt_json, path) { + match peek { + Peek::Object => { + let mut opt_key = jiter.known_object()?; + + let mut keys = Vec::new(); + while let Some(key) = opt_key { + keys.push(key.to_string()); + jiter.next_skip()?; + opt_key = jiter.next_key()?; + } + Ok(keys) + } + _ => get_err!(), + } + } else { + get_err!() + } +} diff --git a/src/lib.rs b/src/lib.rs index 692478e..cb0f25a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ mod json_get_int; mod json_get_json; mod json_get_str; mod json_length; +mod json_object_keys; mod rewrite; pub use common_union::{JsonUnionEncoder, JsonUnionValue}; @@ -31,6 +32,7 @@ pub mod functions { pub use crate::json_get_json::json_get_json; pub use crate::json_get_str::json_get_str; pub use crate::json_length::json_length; + pub use crate::json_object_keys::json_object_keys; } pub mod udfs { @@ -43,6 +45,7 @@ pub mod udfs { pub use crate::json_get_json::json_get_json_udf; pub use crate::json_get_str::json_get_str_udf; pub use crate::json_length::json_length_udf; + pub use crate::json_object_keys::json_object_keys_udf; } /// Register all JSON UDFs, and [`rewrite::JsonFunctionRewriter`] with the provided [`FunctionRegistry`]. @@ -65,6 +68,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { json_get_str::json_get_str_udf(), json_contains::json_contains_udf(), json_length::json_length_udf(), + json_object_keys::json_object_keys_udf(), ]; functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/tests/main.rs b/tests/main.rs index 46cfacf..548d67f 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1432,3 +1432,100 @@ async fn test_dict_filter_contains() { assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_json_object_keys() { + let expected = [ + "+----------------------------------+", + "| json_object_keys(test.json_data) |", + "+----------------------------------+", + "| [foo] |", + "| [foo] |", + "| [foo] |", + "| [foo] |", + "| [bar] |", + "| |", + "| |", + "+----------------------------------+", + ]; + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query_dict(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); + + let sql = "select json_object_keys(json_data) from test"; + let batches = run_query_large(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_object_keys_many() { + let expected = [ + "+-----------------------+", + "| v |", + "+-----------------------+", + "| [foo, bar, spam, ham] |", + "+-----------------------+", + ]; + + let sql = r#"select json_object_keys('{"foo": 1, "bar": 2.2, "spam": true, "ham": []}') as v"#; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_object_keys_nested() { + let json = r#"'{"foo": [{"bar": {"spam": true, "ham": []}}]}'"#; + + let sql = format!("select json_object_keys({json}) as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+-------+", + "| v |", + "+-------+", + "| [foo] |", + "+-------+", + ]; + assert_batches_eq!(expected, &batches); + + let sql = format!("select json_object_keys({json}, 'foo') as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+---+", + "| v |", + "+---+", + "| |", + "+---+", + ]; + assert_batches_eq!(expected, &batches); + + let sql = format!("select json_object_keys({json}, 'foo', 0) as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+-------+", + "| v |", + "+-------+", + "| [bar] |", + "+-------+", + ]; + assert_batches_eq!(expected, &batches); + + let sql = format!("select json_object_keys({json}, 'foo', 0, 'bar') as v"); + let batches = run_query(&sql).await.unwrap(); + #[rustfmt::skip] + let expected = [ + "+-------------+", + "| v |", + "+-------------+", + "| [spam, ham] |", + "+-------------+", + ]; + assert_batches_eq!(expected, &batches); +}