From 0a4493cb575ee41eb793fd8c44f3cc85b3334490 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 5 May 2024 09:42:50 +0100 Subject: [PATCH] support `LargeUtf8` (#7) --- Cargo.toml | 2 +- src/common.rs | 58 ++++++++++++++++++++++++-------- tests/main.rs | 83 +++++++++++++++++++++++++++++++++++++++++++++- tests/utils/mod.rs | 41 +++++++++++++++++------ 4 files changed, 158 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b9f9d6..297b4ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "datafusion-functions-json" -version = "0.1.0" +version = "0.1.1" edition = "2021" description = "JSON functions for DataFusion" readme = "README.md" diff --git a/src/common.rs b/src/common.rs index b93297c..27d41e0 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,6 +1,6 @@ use std::str::Utf8Error; -use arrow::array::{as_string_array, Array, ArrayRef, Int64Array, StringArray, UInt64Array}; +use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, UInt64Array}; use arrow_schema::DataType; use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -10,11 +10,11 @@ pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> { let Some(first) = args.first() else { return plan_err!("The `{fn_name}` function requires one or more arguments."); }; - if !matches!(first, DataType::Utf8) { + if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) { return plan_err!("Unexpected argument type to `{fn_name}` at position 1, expected a string."); } args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg { - DataType::Utf8 | DataType::UInt64 | DataType::Int64 => Ok(()), + DataType::Utf8 | DataType::LargeUtf8 | DataType::UInt64 | DataType::Int64 => Ok(()), _ => plan_err!( "Unexpected argument type to `{fn_name}` at position {}, expected string or int.", index + 2 @@ -71,6 +71,9 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>( if let 
Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() { let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); zip_apply(json_array, paths, jiter_find) + } else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() { + let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key)); + zip_apply(json_array, paths, jiter_find) } else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() { let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into)); zip_apply(json_array, paths, jiter_find) @@ -81,15 +84,9 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>( return exec_err!("unexpected second argument type, expected string or int array"); } } - ColumnarValue::Scalar(_) => { - let path = JsonPath::extract_path(args); - as_string_array(json_array) - .iter() - .map(|opt_json| jiter_find(opt_json, &path).ok()) - .collect::<C>() - } + ColumnarValue::Scalar(_) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find), }; - to_array(result_collect).map(ColumnarValue::from) + to_array(result_collect?).map(ColumnarValue::from) } ColumnarValue::Scalar(ScalarValue::Utf8(s)) => { let path = JsonPath::extract_path(args); @@ -106,9 +103,22 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio json_array: &ArrayRef, paths: P, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>, +) -> DataFusionResult<C> { + if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() { + Ok(zip_apply_iter(string_array.iter(), paths, jiter_find)) + } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() { + Ok(zip_apply_iter(large_string_array.iter(), paths, jiter_find)) + } else { + exec_err!("unexpected json array type") + } +} + +fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>( + json_iter: impl Iterator<Item = Option<&'j str>>, + paths: P, + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>, ) -> C { - as_string_array(json_array) - .iter() + json_iter .zip(paths) .map(|(opt_json, opt_path)| { if let Some(path) = opt_path { @@ -120,6 +130,28 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: 
FromIterator<Optio .collect::<C>() } +fn scalar_apply<'a, C: FromIterator<Option<I>> + 'static, I>( + json_array: &ArrayRef, + path: &[JsonPath], + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>, +) -> DataFusionResult<C> { + if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() { + Ok(scalar_apply_iter(string_array.iter(), path, jiter_find)) + } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() { + Ok(scalar_apply_iter(large_string_array.iter(), path, jiter_find)) + } else { + exec_err!("unexpected json array type") + } +} + +fn scalar_apply_iter<'a, 'j, C: FromIterator<Option<I>> + 'static, I>( + json_iter: impl Iterator<Item = Option<&'j str>>, + path: &[JsonPath], + jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>, +) -> C { + json_iter.map(|opt_json| jiter_find(opt_json, path).ok()).collect::<C>() +} + pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { if let Some(json_str) = opt_json { let mut jiter = Jiter::new(json_str.as_bytes(), false); diff --git a/tests/main.rs b/tests/main.rs index 37de453..263c2ef 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1,8 +1,9 @@ use arrow_schema::DataType; use datafusion::assert_batches_eq; +use datafusion_common::ScalarValue; mod utils; -use utils::{display_val, run_query}; +use utils::{display_val, run_query, run_query_large, run_query_params}; #[tokio::test] async fn test_json_contains() { @@ -348,3 +349,83 @@ async fn test_json_length_object_nested() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string())); } + +#[tokio::test] +async fn test_json_contains_large() { + let expected = [ + "+----------+", + "| COUNT(*) |", + "+----------+", + "| 4 |", + "+----------+", + ]; + + let batches = run_query_large("select count(*) from test where json_contains(json_data, 'foo')") + .await + .unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_contains_large_vec() { + let expected = [ 
"+----------+", + "| COUNT(*) |", + "+----------+", + "| 0 |", + "+----------+", + ]; + + let batches = run_query_large("select count(*) from test where json_contains(json_data, name)") + .await + .unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_contains_large_both() { + let expected = [ + "+----------+", + "| COUNT(*) |", + "+----------+", + "| 0 |", + "+----------+", + ]; + + let batches = run_query_large("select count(*) from test where json_contains(json_data, json_data)") + .await + .unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_contains_large_params() { + let expected = [ + "+----------+", + "| COUNT(*) |", + "+----------+", + "| 4 |", + "+----------+", + ]; + + let sql = "select count(*) from test where json_contains(json_data, 'foo')"; + let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; + let batches = run_query_params(sql, false, params).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_json_contains_large_both_params() { + let expected = [ + "+----------+", + "| COUNT(*) |", + "+----------+", + "| 4 |", + "+----------+", + ]; + + let sql = "select count(*) from test where json_contains(json_data, 'foo')"; + let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))]; + let batches = run_query_params(sql, true, params).await.unwrap(); + assert_batches_eq!(expected, &batches); +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index c84cdb8..f812cc6 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -1,15 +1,17 @@ #![allow(dead_code)] -use arrow::array::Int64Array; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::util::display::{ArrayFormatter, FormatOptions}; -use arrow::{array::StringArray, record_batch::RecordBatch}; -use std::sync::Arc; +use arrow::{array::LargeStringArray, array::StringArray, 
record_batch::RecordBatch}; use datafusion::error::Result; use datafusion::execution::context::SessionContext; +use datafusion_common::ParamValues; use datafusion_functions_json::register_all; -async fn create_test_table() -> Result<SessionContext> { +async fn create_test_table(large_utf8: bool) -> Result<SessionContext> { let mut ctx = SessionContext::new(); register_all(&mut ctx)?; @@ -22,18 +24,22 @@ async fn create_test_table() -> Result<SessionContext> { ("list_foo", r#" ["foo"] "#), ("invalid_json", "is not json"), ]; + let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>(); + let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 { + (DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values))) + } else { + (DataType::Utf8, Arc::new(StringArray::from(json_values))) + }; let test_batch = RecordBatch::try_new( Arc::new(Schema::new(vec![ Field::new("name", DataType::Utf8, false), - Field::new("json_data", DataType::Utf8, false), + Field::new("json_data", json_data_type, false), ])), vec![ Arc::new(StringArray::from( test_data.iter().map(|(name, _)| *name).collect::<Vec<_>>(), )), - Arc::new(StringArray::from( - test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>(), - )), + json_array, ], )?; ctx.register_batch("test", test_batch)?; @@ -68,9 +74,22 @@ async fn create_test_table() -> Result<SessionContext> { } pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> { - let ctx = create_test_table().await?; - let df = ctx.sql(sql).await?; - df.collect().await + let ctx = create_test_table(false).await?; + ctx.sql(sql).await?.collect().await +} + +pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> { + let ctx = create_test_table(true).await?; + ctx.sql(sql).await?.collect().await +} + +pub async fn run_query_params( + sql: &str, + large_utf8: bool, + query_values: impl Into<ParamValues>, +) -> Result<Vec<RecordBatch>> { + let ctx = create_test_table(large_utf8).await?; + ctx.sql(sql).await?.with_param_values(query_values)?.collect().await } pub async fn display_val(batch: Vec<RecordBatch>) -> (DataType, String) {