diff --git a/src/json_get_int.rs b/src/json_get_int.rs index 6680aa9..094ea1c 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -1,7 +1,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{as_string_array, Int64Array}; +use arrow::array::{as_string_array, Int64Array, StringArray}; use arrow_schema::DataType; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::{exec_err, Result as DataFusionResult, ScalarValue}; @@ -51,18 +51,34 @@ impl ScalarUDFImpl for JsonGetInt { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - let path = JsonPath::extract_args(args); - match &args[0] { - ColumnarValue::Array(array) => { - let array = as_string_array(array) - .iter() - .map(|opt_json| jiter_json_get_int(opt_json, &path).ok()) - .collect::(); - - Ok(ColumnarValue::from(Arc::new(array) as ArrayRef)) + ColumnarValue::Array(json_array) => { + let result_array = match &args[1] { + ColumnarValue::Array(a) => { + if let Some(str_path_array) = a.as_any().downcast_ref::() { + let paths = str_path_array.iter().map(|opt_key| opt_key.map(|s| JsonPath::Key(s))); + zip_apply(json_array, paths) + } else if let Some(int_path_array) = a.as_any().downcast_ref::() { + let paths = int_path_array + .iter() + .map(|opt_index| opt_index.map(|i| JsonPath::Index(i as usize))); + zip_apply(json_array, paths) + } else { + return exec_err!("unexpected second argument type, expected string or int array"); + } + } + _ => { + let path = JsonPath::extract_args(args); + as_string_array(json_array) + .iter() + .map(|opt_json| jiter_json_get_int(opt_json, &path).ok()) + .collect::() + } + }; + Ok(ColumnarValue::from(Arc::new(result_array) as ArrayRef)) } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { + let path = JsonPath::extract_args(args); let v = jiter_json_get_int(s.as_ref().map(|s| s.as_str()), &path).ok(); Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) } @@ -77,6 +93,20 @@ impl ScalarUDFImpl for JsonGetInt { } } +fn zip_apply<'a, T: Iterator>>>(json_array: &ArrayRef, paths: T) -> Int64Array { + as_string_array(json_array) + .iter() + .zip(paths) + .map(|(opt_json, opt_path)| { + if let Some(path) = opt_path { + jiter_json_get_int(opt_json, &[path]).ok() + } else { + None + } + }) + .collect::() +} + fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { match peek { diff --git a/tests/test_json_get.rs b/tests/test_json_get.rs index cf45196..31303ea 100644 --- a/tests/test_json_get.rs +++ b/tests/test_json_get.rs @@ -84,9 +84,8 @@ async fn test_json_get_str() { #[tokio::test] async fn test_json_get_str_equals() { - let batches = run_query("select name, json_get_str(json_data, 'foo')='abc' from test") - .await - .unwrap(); + let sql = "select name, json_get_str(json_data, 'foo')='abc' from test"; + let batches = run_query(sql).await.unwrap(); let expected = [ "+------------------+--------------------------------------------------------+", @@ -167,6 +166,32 @@ async fn test_json_get_cast_int_path() { assert_eq!(display_val(batches).await, (DataType::Int64, "73".to_string())); } +#[tokio::test] +async fn test_json_get_int_lookup() { + let sql = "select str_key, json_data from other where json_get_int(json_data, str_key) is not null"; + let batches = run_query(sql).await.unwrap(); + let expected = [ + "+---------+---------------+", + "| str_key | json_data |", + "+---------+---------------+", + "| foo | {\"foo\": 42} |", + "+---------+---------------+", + ]; + assert_batches_eq!(expected, &batches); + + // lookup by int + let sql = "select int_key, json_data from other where json_get_int(json_data, int_key) is not null"; + let batches = run_query(sql).await.unwrap(); + let expected = [ + "+---------+-----------+", + "| int_key | json_data |", + "+---------+-----------+", + "| 0 | [42] |", + "+---------+-----------+", + ]; + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_json_get_float() { let batches = run_query(r#"select json_get_float('[1.5]', 0) as v"#).await.unwrap(); diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index ee39a02..34e505b 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +use arrow::array::Int64Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::util::display::{ArrayFormatter, FormatOptions}; use arrow::{array::StringArray, record_batch::RecordBatch}; @@ -9,12 +10,10 @@ use datafusion::execution::context::SessionContext; use datafusion_functions_json::register_all; async fn create_test_table() -> Result { - let schema = Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("json_data", DataType::Utf8, false), - ])); + let mut ctx = SessionContext::new(); + register_all(&mut ctx)?; - let data = [ + let test_data = [ ("object_foo", r#" {"foo": "abc"} "#), ("object_foo_array", r#" {"foo": [1]} "#), ("object_foo_obj", r#" {"foo": {}} "#), @@ -23,22 +22,48 @@ async fn create_test_table() -> Result { ("list_foo", r#" ["foo"] "#), ("invalid_json", "is not json"), ]; + let test_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("json_data", DataType::Utf8, false), + ])), + vec![ + Arc::new(StringArray::from( + test_data.iter().map(|(name, _)| *name).collect::>(), + )), + Arc::new(StringArray::from( + test_data.iter().map(|(_, json)| *json).collect::>(), + )), + ], + )?; + ctx.register_batch("test", test_batch)?; - let batch = RecordBatch::try_new( - schema, + let other_data = [ + (r#" {"foo": 42} "#, "foo", 0), + (r#" {"foo": 42} "#, "bar", 1), + (r#" [42] "#, "foo", 0), + (r#" [42] "#, "bar", 1), + ]; + let other_batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("json_data", DataType::Utf8, false), + Field::new("str_key", DataType::Utf8, false), + Field::new("int_key", DataType::Int64, false), + ])), vec![ Arc::new(StringArray::from( - data.iter().map(|(name, _)| *name).collect::>(), + other_data.iter().map(|(json, _, _)| *json).collect::>(), )), Arc::new(StringArray::from( - data.iter().map(|(_, json)| *json).collect::>(), + other_data.iter().map(|(_, str_key, _)| *str_key).collect::>(), + )), + Arc::new(Int64Array::from( + other_data.iter().map(|(_, _, int_key)| *int_key).collect::>(), )), ], )?; + ctx.register_batch("other", other_batch)?; - let mut ctx = SessionContext::new(); - register_all(&mut ctx)?; - ctx.register_batch("test", batch)?; Ok(ctx) }