Skip to content

Commit

Permalink
Merge branch 'main' into allow-json-contains-filter
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Sep 17, 2024
2 parents 3767061 + 5738ff3 commit 85ac665
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
Cargo.lock
.idea
45 changes: 44 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,47 @@ To use these functions, you'll just need to call:
```rust
datafusion_functions_json::register_all(&mut ctx)?;
```

To register the below JSON functions in your `SessionContext`.

# Examples

```sql
-- Create a table with a JSON column stored as a string
CREATE TABLE test_table (id INT, json_col VARCHAR) AS VALUES
(1, '{}'),
(2, '{ "a": 1 }'),
(3, '{ "a": 2 }'),
(4, '{ "a": 1, "b": 2 }'),
(5, '{ "a": 1, "b": 2, "c": 3 }');

-- Check if each document contains the key 'b'
SELECT id, json_contains(json_col, 'b') as json_contains FROM test_table;
-- Results in
-- +----+---------------+
-- | id | json_contains |
-- +----+---------------+
-- | 1 | false |
-- | 2 | false |
-- | 3 | false |
-- | 4 | true |
-- | 5 | true |
-- +----+---------------+

-- Get the value of the key 'a' from each document
SELECT id, json_col->'a' as json_col_a FROM test_table;

-- +----+------------+
-- | id | json_col_a |
-- +----+------------+
-- | 1 | {null=} |
-- | 2 | {int=1} |
-- | 3 | {int=2} |
-- | 4 | {int=1} |
-- | 5 | {int=1} |
-- +----+------------+
```


## Done

* [x] `json_contains(json: str, *keys: str | int) -> bool` - true if a JSON string has a specific key (used for the `?` operator)
Expand All @@ -27,6 +65,11 @@ To register the below JSON functions in your `SessionContext`.
* [x] `json_as_text(json: str, *keys: str | int) -> str` - Get any value from a JSON string by its "path", represented as a string (used for the `->>` operator)
* [x] `json_length(json: str, *keys: str | int) -> int` - get the length of a JSON string or array

* [x] `->` operator - alias for `json_get`
* [x] `->>` operator - alias for `json_as_text`
* [x] `?` operator - alias for `json_contains`

### Notes
Cast expressions with `json_get` are rewritten to the appropriate method, e.g.

```sql
Expand Down
123 changes: 87 additions & 36 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use std::str::Utf8Error;
use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, UInt64Array, UnionArray,
};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};

/// General implementation of `ScalarUDFImpl::return_type` to check if the arguments are valid.
///
Expand Down Expand Up @@ -169,24 +172,39 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
object_lookup: bool,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
// NOTE we do NOT map back to an dictionary as that doesn't work for `is null` or filtering
// see https://github.com/apache/datafusion/issues/12380
let values = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
return unpack_dict_array(d.with_values(values)).map_err(Into::into);
}
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = downcast_dictionary_array!(
json_array => {
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

// if let Some(d) = json_array.as_any_dictionary_opt() {
// // NOTE we do NOT map back to an dictionary as that doesn't work for `is null` or filtering
// // see https://github.com/apache/datafusion/issues/12380
// let values = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
// return unpack_dict_array(d.with_values(values)).map_err(Into::into);
// }

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
zip_apply_iter(large_string_array.iter(), path_array, jiter_find)
} else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() {
zip_apply_iter(string_view.iter(), path_array, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
to_array(c)
}

Expand Down Expand Up @@ -237,24 +255,37 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
// as above, don't return a dict
let values = scalar_apply(d.values(), path, to_array, jiter_find)?;
return unpack_dict_array(d.with_values(values)).map_err(Into::into);
}
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
scalar_apply_iter(large_string_array.iter(), path, jiter_find)
} else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() {
scalar_apply_iter(string_view_array.iter(), path, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
let c = downcast_dictionary_array!(
json_array => {
let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

// if let Some(d) = json_array.as_any_dictionary_opt() {
// // as above, don't return a dict
// let values = scalar_apply(d.values(), path, to_array, jiter_find)?;
// return unpack_dict_array(d.with_values(values)).map_err(Into::into);
// }
to_array(c)
}

Expand Down Expand Up @@ -328,3 +359,23 @@ impl From<Utf8Error> for GetError {
GetError
}
}

/// Null-out dictionary keys whose referenced union member is the JSON null variant.
///
/// This is a workaround to <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753>
/// - i.e. that dictionary null is most reliably done if the keys are null.
///
/// That said, doing this might also be an optimization for cases like null-checking without needing
/// to check the value union array.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
    // A key stays valid only when it is itself non-null AND the union slot it
    // points at is not the null member; everything else becomes a null key.
    let validity: Vec<bool> = keys
        .iter()
        .map(|key| matches!(key, Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL))
        .collect();
    PrimitiveArray::new(keys.values().clone(), Some(validity.into()))
}
2 changes: 1 addition & 1 deletion src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub(crate) enum JsonUnionField {
Object(String),
}

const TYPE_ID_NULL: i8 = 0;
pub(crate) const TYPE_ID_NULL: i8 = 0;
const TYPE_ID_BOOL: i8 = 1;
const TYPE_ID_INT: i8 = 2;
const TYPE_ID_FLOAT: i8 = 3;
Expand Down
56 changes: 49 additions & 7 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;
use datafusion::prelude::SessionContext;
use datafusion_functions_json::udfs::json_get_str_udf;
use utils::{create_context, display_val, logical_plan, run_query, run_query_large, run_query_params};
use utils::{create_context, display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params};

mod utils;

Expand Down Expand Up @@ -1076,6 +1076,28 @@ async fn test_arrow_union_is_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_null_dict_encoded() {
    // Same query as `test_arrow_union_is_null`, but against the dictionary-encoded
    // JSON column, exercising the dictionary key-masking path.
    let result = run_query_dict("select name, (json_data->'foo') is null from test")
        .await
        .unwrap();

    let expected_table = [
        "+------------------+---------------------------------------+",
        "| name             | test.json_data -> Utf8(\"foo\") IS NULL |",
        "+------------------+---------------------------------------+",
        "| object_foo       | false                                 |",
        "| object_foo_array | false                                 |",
        "| object_foo_obj   | false                                 |",
        "| object_foo_null  | true                                  |",
        "| object_bar       | true                                  |",
        "| list_foo         | true                                  |",
        "| invalid_json     | true                                  |",
        "+------------------+---------------------------------------+",
    ];
    assert_batches_eq!(expected_table, &result);
}

#[tokio::test]
async fn test_arrow_union_is_not_null() {
let batches = run_query("select name, (json_data->'foo') is not null from test")
Expand All @@ -1098,6 +1120,28 @@ async fn test_arrow_union_is_not_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_not_null_dict_encoded() {
    // Mirror of `test_arrow_union_is_not_null`, run against the dictionary-encoded
    // JSON column so null masking of dictionary keys is covered.
    let result = run_query_dict("select name, (json_data->'foo') is not null from test")
        .await
        .unwrap();

    let expected_table = [
        "+------------------+-------------------------------------------+",
        "| name             | test.json_data -> Utf8(\"foo\") IS NOT NULL |",
        "+------------------+-------------------------------------------+",
        "| object_foo       | true                                      |",
        "| object_foo_array | true                                      |",
        "| object_foo_obj   | true                                      |",
        "| object_foo_null  | false                                     |",
        "| object_bar       | false                                     |",
        "| list_foo         | false                                     |",
        "| invalid_json     | false                                     |",
        "+------------------+-------------------------------------------+",
    ];
    assert_batches_eq!(expected_table, &result);
}

#[tokio::test]
async fn test_arrow_scalar_union_is_null() {
let batches = run_query(
Expand Down Expand Up @@ -1151,9 +1195,8 @@ async fn test_dict_haystack() {
"| v |",
"+-----------------------+",
"| {object={\"bar\": [0]}} |",
"| {null=} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-----------------------+",
];

Expand Down Expand Up @@ -1184,9 +1227,8 @@ async fn test_dict_haystack_needle() {
"| v |",
"+-------------+",
"| {array=[0]} |",
"| {null=} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-------------+",
];

Expand Down
28 changes: 21 additions & 7 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
use std::sync::Arc;

use datafusion::arrow::array::{
ArrayRef, DictionaryArray, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
use datafusion::common::ParamValues;
Expand All @@ -20,7 +20,7 @@ pub async fn create_context() -> Result<SessionContext> {
Ok(ctx)
}

async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<SessionContext> {
let ctx = create_context().await?;

let test_data = [
Expand All @@ -33,11 +33,20 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
("invalid_json", "is not json"),
];
let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>();
let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 {
(DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
} else {
(DataType::Utf8, Arc::new(StringArray::from(json_values)))
};

if dict_encoded {
json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into());
json_array = Arc::new(DictionaryArray::<Int32Type>::new(
Int32Array::from_iter_values(0..(json_array.len() as i32)),
json_array,
));
}

let test_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Expand Down Expand Up @@ -184,12 +193,17 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
}

/// Run `sql` against the default test table (plain `Utf8` JSON column, no dictionary encoding).
pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
    let context = create_test_table(false, false).await?;
    let frame = context.sql(sql).await?;
    frame.collect().await
}

/// Run `sql` against the test table with the JSON column stored as `LargeUtf8`.
pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
    let context = create_test_table(true, false).await?;
    let frame = context.sql(sql).await?;
    frame.collect().await
}

/// Run `sql` against the test table with the JSON column dictionary-encoded (`Dictionary<Int32, Utf8>`).
pub async fn run_query_dict(sql: &str) -> Result<Vec<RecordBatch>> {
    let context = create_test_table(false, true).await?;
    let frame = context.sql(sql).await?;
    frame.collect().await
}

Expand All @@ -198,7 +212,7 @@ pub async fn run_query_params(
large_utf8: bool,
query_values: impl Into<ParamValues>,
) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table(large_utf8).await?;
let ctx = create_test_table(large_utf8, false).await?;
ctx.sql(sql).await?.with_param_values(query_values)?.collect().await
}

Expand Down

0 comments on commit 85ac665

Please sign in to comment.