Skip to content

Commit

Permalink
set keys to null where applicable in dictionary-encoded results (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt authored Sep 16, 2024
1 parent 4a33b9f commit f3d5366
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 47 deletions.
109 changes: 76 additions & 33 deletions src/common.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use std::str::Utf8Error;
use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, UInt64Array, UnionArray,
};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array};
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -164,21 +167,32 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
object_lookup: bool,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
return Ok(d.with_values(a).into());
}
let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
zip_apply_iter(large_string_array.iter(), path_array, jiter_find)
} else if let Some(string_view) = json_array.as_any().downcast_ref::<StringViewArray>() {
zip_apply_iter(string_view.iter(), path_array, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = downcast_dictionary_array!(
json_array => {
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
zip_apply_iter(string_array.iter(), path_array, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

to_array(c)
}

Expand Down Expand Up @@ -229,22 +243,31 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
let a = scalar_apply(d.values(), path, to_array, jiter_find)?;
return Ok(d.with_values(a).into());
}
// arrow_schema "use" is workaround for https://github.com/apache/arrow-rs/issues/6400#issue-2528388332
use datafusion::arrow::datatypes as arrow_schema;

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else if let Some(large_string_array) = json_array.as_any().downcast_ref::<LargeStringArray>() {
scalar_apply_iter(large_string_array.iter(), path, jiter_find)
} else if let Some(string_view_array) = json_array.as_any().downcast_ref::<StringViewArray>() {
scalar_apply_iter(string_view_array.iter(), path, jiter_find)
} else if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", json_array.data_type());
};
let c = downcast_dictionary_array!(
json_array => {
let values = scalar_apply(json_array.values(), path, to_array, jiter_find)?;
if !is_json_union(values.data_type()) {
return Ok(Arc::new(json_array.with_values(values)));
}
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
return Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(json_array.keys(), type_ids),
values,
)));
}
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
scalar_apply_iter(string_array.iter(), path, jiter_find)
} else {
return exec_err!("unexpected json array type {:?}", other);
}
);

to_array(c)
}
Expand Down Expand Up @@ -319,3 +342,23 @@ impl From<Utf8Error> for GetError {
GetError
}
}

/// Set keys to null where the union member is null.
///
/// This is a workaround to <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753>
/// - i.e. that dictionary null is most reliably done if the keys are null.
///
/// That said, doing this might also be an optimization for cases like null-checking without needing
/// to check the value union array.
/// Build a copy of `keys` whose validity marks a slot null whenever the key is
/// null or the union value it points at is the JSON null variant.
///
/// Workaround for <https://github.com/apache/arrow-rs/issues/6017#issuecomment-2352756753>
/// - dictionary nulls are most reliably represented via null keys.
///
/// As a side benefit, this may speed up null checks, since readers can consult
/// the key validity without touching the union values array.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
    // A slot stays valid only when the key is present AND the union member it
    // references is not the null variant.
    let validity: Vec<bool> = keys
        .iter()
        .map(|key| matches!(key, Some(k) if type_ids[k.as_usize()] != TYPE_ID_NULL))
        .collect();
    PrimitiveArray::new(keys.values().clone(), Some(validity.into()))
}
2 changes: 1 addition & 1 deletion src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ pub(crate) enum JsonUnionField {
Object(String),
}

const TYPE_ID_NULL: i8 = 0;
pub(crate) const TYPE_ID_NULL: i8 = 0;
const TYPE_ID_BOOL: i8 = 1;
const TYPE_ID_INT: i8 = 2;
const TYPE_ID_FLOAT: i8 = 3;
Expand Down
2 changes: 1 addition & 1 deletion src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
/// Peel off any `Alias` wrappers and return the inner scalar function call, if present.
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
    if let Expr::ScalarFunction(func) = expr {
        return Some(func);
    }
    if let Expr::Alias(alias) = expr {
        return extract_scalar_function(&alias.expr);
    }
    None
}
Expand Down
54 changes: 49 additions & 5 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;

use datafusion_functions_json::udfs::json_get_str_udf;
use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};
use utils::{display_val, logical_plan, run_query, run_query_dict, run_query_large, run_query_params};

mod utils;

Expand Down Expand Up @@ -1072,6 +1072,28 @@ async fn test_arrow_union_is_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_null_dict_encoded() {
// Like `test_arrow_union_is_null`, but runs against the table whose `json_data`
// column is dictionary-encoded (see `run_query_dict`), exercising the null-key
// masking of dictionary-encoded JSON union results (`mask_dictionary_keys`).
let batches = run_query_dict("select name, (json_data->'foo') is null from test")
.await
.unwrap();

// Rows where 'foo' is absent, explicitly null, or the JSON is invalid must
// all report IS NULL = true.
let expected = [
"+------------------+---------------------------------------+",
"| name | test.json_data -> Utf8(\"foo\") IS NULL |",
"+------------------+---------------------------------------+",
"| object_foo | false |",
"| object_foo_array | false |",
"| object_foo_obj | false |",
"| object_foo_null | true |",
"| object_bar | true |",
"| list_foo | true |",
"| invalid_json | true |",
"+------------------+---------------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_not_null() {
let batches = run_query("select name, (json_data->'foo') is not null from test")
Expand All @@ -1094,6 +1116,28 @@ async fn test_arrow_union_is_not_null() {
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_not_null_dict_encoded() {
// Mirror of `test_arrow_union_is_null_dict_encoded` using IS NOT NULL: the
// dictionary-encoded `json_data` column must yield the inverted truth table.
let batches = run_query_dict("select name, (json_data->'foo') is not null from test")
.await
.unwrap();

let expected = [
"+------------------+-------------------------------------------+",
"| name | test.json_data -> Utf8(\"foo\") IS NOT NULL |",
"+------------------+-------------------------------------------+",
"| object_foo | true |",
"| object_foo_array | true |",
"| object_foo_obj | true |",
"| object_foo_null | false |",
"| object_bar | false |",
"| list_foo | false |",
"| invalid_json | false |",
"+------------------+-------------------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_scalar_union_is_null() {
let batches = run_query(
Expand Down Expand Up @@ -1147,8 +1191,8 @@ async fn test_dict_haystack() {
"| v |",
"+-----------------------+",
"| {object={\"bar\": [0]}} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-----------------------+",
];

Expand All @@ -1164,8 +1208,8 @@ async fn test_dict_haystack_needle() {
"| v |",
"+-------------+",
"| {array=[0]} |",
"| {null=} |",
"| {null=} |",
"| |",
"| |",
"+-------------+",
];

Expand Down
28 changes: 21 additions & 7 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
use std::sync::Arc;

use datafusion::arrow::array::{
ArrayRef, DictionaryArray, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
ArrayRef, DictionaryArray, Int32Array, Int64Array, StringViewArray, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type, Int64Type, Schema, UInt32Type, UInt8Type};
use datafusion::arrow::util::display::{ArrayFormatter, FormatOptions};
use datafusion::arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBatch};
use datafusion::common::ParamValues;
Expand All @@ -13,7 +13,7 @@ use datafusion::execution::context::SessionContext;
use datafusion::prelude::SessionConfig;
use datafusion_functions_json::register_all;

async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
async fn create_test_table(large_utf8: bool, dict_encoded: bool) -> Result<SessionContext> {
let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
let mut ctx = SessionContext::new_with_config(config);
register_all(&mut ctx)?;
Expand All @@ -28,11 +28,20 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
("invalid_json", "is not json"),
];
let json_values = test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>();
let (json_data_type, json_array): (DataType, ArrayRef) = if large_utf8 {
let (mut json_data_type, mut json_array): (DataType, ArrayRef) = if large_utf8 {
(DataType::LargeUtf8, Arc::new(LargeStringArray::from(json_values)))
} else {
(DataType::Utf8, Arc::new(StringArray::from(json_values)))
};

if dict_encoded {
json_data_type = DataType::Dictionary(DataType::Int32.into(), json_data_type.into());
json_array = Arc::new(DictionaryArray::<Int32Type>::new(
Int32Array::from_iter_values(0..(json_array.len() as i32)),
json_array,
));
}

let test_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Expand Down Expand Up @@ -178,12 +187,17 @@ async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
}

/// Run `sql` against the plain Utf8 (non-dictionary) test table and collect the results.
pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
    let ctx = create_test_table(false, false).await?;
    let df = ctx.sql(sql).await?;
    df.collect().await
}

/// Run `sql` against the test table built with LargeUtf8 JSON columns.
pub async fn run_query_large(sql: &str) -> Result<Vec<RecordBatch>> {
    let ctx = create_test_table(true, false).await?;
    let df = ctx.sql(sql).await?;
    df.collect().await
}

/// Run `sql` against the test table whose JSON column is dictionary-encoded.
pub async fn run_query_dict(sql: &str) -> Result<Vec<RecordBatch>> {
    let ctx = create_test_table(false, true).await?;
    let df = ctx.sql(sql).await?;
    df.collect().await
}

Expand All @@ -192,7 +206,7 @@ pub async fn run_query_params(
large_utf8: bool,
query_values: impl Into<ParamValues>,
) -> Result<Vec<RecordBatch>> {
let ctx = create_test_table(large_utf8).await?;
let ctx = create_test_table(large_utf8, false).await?;
ctx.sql(sql).await?.with_param_values(query_values)?.collect().await
}

Expand Down

0 comments on commit f3d5366

Please sign in to comment.