Skip to content

Commit

Permalink
support LargeUtf8 (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored May 5, 2024
1 parent dac4d1a commit 0a4493c
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "datafusion-functions-json"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
description = "JSON functions for DataFusion"
readme = "README.md"
Expand Down
58 changes: 45 additions & 13 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::str::Utf8Error;

use arrow::array::{as_string_array, Array, ArrayRef, Int64Array, StringArray, UInt64Array};
use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, UInt64Array};
use arrow_schema::DataType;
use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion_expr::ColumnarValue;
Expand All @@ -10,11 +10,11 @@ pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> {
let Some(first) = args.first() else {
return plan_err!("The `{fn_name}` function requires one or more arguments.");
};
if !matches!(first, DataType::Utf8) {
if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) {
return plan_err!("Unexpected argument type to `{fn_name}` at position 1, expected a string.");
}
args[1..].iter().enumerate().try_for_each(|(index, arg)| match arg {
DataType::Utf8 | DataType::UInt64 | DataType::Int64 => Ok(()),
DataType::Utf8 | DataType::LargeUtf8 | DataType::UInt64 | DataType::Int64 => Ok(()),
_ => plan_err!(
"Unexpected argument type to `{fn_name}` at position {}, expected string or int.",
index + 2
Expand Down Expand Up @@ -71,6 +71,9 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find)
} else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find)
} else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() {
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, jiter_find)
Expand All @@ -81,15 +84,9 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
return exec_err!("unexpected second argument type, expected string or int array");
}
}
ColumnarValue::Scalar(_) => {
let path = JsonPath::extract_path(args);
as_string_array(json_array)
.iter()
.map(|opt_json| jiter_find(opt_json, &path).ok())
.collect::<C>()
}
ColumnarValue::Scalar(_) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
};
to_array(result_collect).map(ColumnarValue::from)
to_array(result_collect?).map(ColumnarValue::from)
}
ColumnarValue::Scalar(ScalarValue::Utf8(s)) => {
let path = JsonPath::extract_path(args);
Expand All @@ -106,9 +103,22 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
json_array: &ArrayRef,
paths: P,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<C> {
if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
Ok(zip_apply_iter(string_array.iter(), paths, jiter_find))
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
Ok(zip_apply_iter(large_string_array.iter(), paths, jiter_find))
} else {
exec_err!("unexpected json array type")
}
}

fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
json_iter: impl Iterator<Item = Option<&'j str>>,
paths: P,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> C {
as_string_array(json_array)
.iter()
json_iter
.zip(paths)
.map(|(opt_json, opt_path)| {
if let Some(path) = opt_path {
Expand All @@ -120,6 +130,28 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
.collect::<C>()
}

/// Apply `jiter_find` with one fixed (scalar) JSON path to every row of a JSON
/// string array, collecting the per-row results into `C`.
///
/// Accepts both `StringArray` (`Utf8`) and `LargeStringArray` (`LargeUtf8`)
/// inputs; any other array type is an execution error.
///
/// Note: the original declared an unused lifetime parameter `'a`; it has been
/// removed (no type in the signature referenced it).
fn scalar_apply<C: FromIterator<Option<I>> + 'static, I>(
    json_array: &ArrayRef,
    path: &[JsonPath],
    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<C> {
    if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
        Ok(scalar_apply_iter(string_array.iter(), path, jiter_find))
    } else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
        Ok(scalar_apply_iter(large_string_array.iter(), path, jiter_find))
    } else {
        exec_err!("unexpected json array type")
    }
}

/// Run `jiter_find` against every JSON string yielded by `json_iter`, using the
/// same `path` for each row; rows whose lookup fails collect as `None`.
///
/// Note: the original declared an unused lifetime parameter `'a` (only `'j` is
/// referenced); it has been removed.
fn scalar_apply_iter<'j, C: FromIterator<Option<I>> + 'static, I>(
    json_iter: impl Iterator<Item = Option<&'j str>>,
    path: &[JsonPath],
    jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> C {
    json_iter.map(|opt_json| jiter_find(opt_json, path).ok()).collect::<C>()
}

pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
if let Some(json_str) = opt_json {
let mut jiter = Jiter::new(json_str.as_bytes(), false);
Expand Down
83 changes: 82 additions & 1 deletion tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use arrow_schema::DataType;
use datafusion::assert_batches_eq;
use datafusion_common::ScalarValue;

mod utils;
use utils::{display_val, run_query};
use utils::{display_val, run_query, run_query_large, run_query_params};

#[tokio::test]
async fn test_json_contains() {
Expand Down Expand Up @@ -348,3 +349,83 @@ async fn test_json_length_object_nested() {
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string()));
}

#[tokio::test]
async fn test_json_contains_large() {
    // `json_contains` with a literal key, run against the table whose
    // `json_data` column is `LargeUtf8` instead of `Utf8`.
    let sql = "select count(*) from test where json_contains(json_data, 'foo')";
    let batches = run_query_large(sql).await.unwrap();

    let expected = [
        "+----------+",
        "| COUNT(*) |",
        "+----------+",
        "| 4 |",
        "+----------+",
    ];
    assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_contains_large() {
    // Key argument is the `name` column (a string array, not a scalar),
    // exercising the array-path branch with a LargeUtf8 JSON column;
    // no row's JSON contains its own name as a key, hence count 0.
    let sql = "select count(*) from test where json_contains(json_data, name)";
    let batches = run_query_large(sql).await.unwrap();

    let expected = [
        "+----------+",
        "| COUNT(*) |",
        "+----------+",
        "| 0 |",
        "+----------+",
    ];
    assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_contains_large_both() {
    // Both arguments are the LargeUtf8 `json_data` column, so both the JSON
    // input and the key path take the LargeStringArray branch; no JSON value
    // contains itself as a key, hence count 0.
    let sql = "select count(*) from test where json_contains(json_data, json_data)";
    let batches = run_query_large(sql).await.unwrap();

    let expected = [
        "+----------+",
        "| COUNT(*) |",
        "+----------+",
        "| 0 |",
        "+----------+",
    ];
    assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_contains_large_params() {
    // Bind the search key as a LargeUtf8 *parameter* against the plain Utf8
    // test table, exercising the LargeUtf8 scalar-argument path.
    // Fix: the original SQL used the literal 'foo' with no `$1` placeholder,
    // so the bound parameter was never used and the LargeUtf8 scalar path went
    // untested; the expected count (4) is the same either way.
    let expected = [
        "+----------+",
        "| COUNT(*) |",
        "+----------+",
        "| 4 |",
        "+----------+",
    ];

    let sql = "select count(*) from test where json_contains(json_data, $1)";
    let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))];
    let batches = run_query_params(sql, false, params).await.unwrap();
    assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_contains_large_both_params() {
    // LargeUtf8 on both sides: the table's `json_data` column is LargeUtf8 AND
    // the key is bound as a LargeUtf8 parameter.
    // Fix: the original SQL used the literal 'foo' with no `$1` placeholder,
    // so the bound parameter was never used; the expected count (4) is the
    // same either way.
    let expected = [
        "+----------+",
        "| COUNT(*) |",
        "+----------+",
        "| 4 |",
        "+----------+",
    ];

    let sql = "select count(*) from test where json_contains(json_data, $1)";
    let params = vec![ScalarValue::LargeUtf8(Some("foo".to_string()))];
    let batches = run_query_params(sql, true, params).await.unwrap();
    assert_batches_eq!(expected, &batches);
}
41 changes: 30 additions & 11 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
#![allow(dead_code)]
use arrow::array::Int64Array;
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{array::StringArray, record_batch::RecordBatch};
use std::sync::Arc;
use arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};

use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion_common::ParamValues;
use datafusion_functions_json::register_all;

async fn create_test_table() -> Result<SessionContext> {
async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
let mut ctx = SessionContext::new();
register_all(&mut ctx)?;

Expand All @@ -22,18 +24,22 @@ async fn create_test_table() -> Result<SessionContext> {
("list_foo", r#" ["foo"] "#),
("invalid_json", "is not json"),
];
let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>();
let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
(DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
} else {
(DataType::Utf8, Arc::new(StringArray::from(json_values)))
};
let test_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("json_data", DataType::Utf8, false),
Field::new("json_data", json_data_type, false),
])),
vec![
Arc::new(StringArray::from(
test_data.iter().map(|(name, _)| *name).collect::<Vec<_>>(),
)),
Arc::new(StringArray::from(
test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>(),
)),
json_array,
],
)?;
ctx.register_batch("test", test_batch)?;
Expand Down Expand Up @@ -68,9 +74,22 @@ async fn create_test_table() -> Result<SessionContext> {
}

pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table().await?;
let df = ctx.sql(sql).await?;
df.collect().await
let ctx = create_test_table(false).await?;
ctx.sql(sql).await?.collect().await
}

/// Execute `sql` against a test table whose JSON column is `LargeUtf8`.
pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
    let ctx = create_test_table(true).await?;
    let df = ctx.sql(sql).await?;
    df.collect().await
}

/// Execute `sql` with bound parameter values against the test table;
/// `large_utf8` selects whether the JSON column is `LargeUtf8` or `Utf8`.
pub async fn run_query_params(
    sql: &str,
    large_utf8: bool,
    query_values: impl Into<ParamValues>,
) -> Result<Vec<RecordBatch>> {
    let ctx = create_test_table(large_utf8).await?;
    let df = ctx.sql(sql).await?;
    let bound = df.with_param_values(query_values)?;
    bound.collect().await
}

pub async fn display_val(batch: Vec<RecordBatch>) -> (DataType, String) {
Expand Down

0 comments on commit 0a4493c

Please sign in to comment.