Skip to content

Commit

Permalink
fix returning of values vs array (#43)
Browse files Browse the repository at this point in the history
Co-authored-by: Samuel Colvin <[email protected]>
  • Loading branch information
adriangb and samuelcolvin authored Sep 9, 2024
1 parent 7cfd288 commit 3767061
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 29 deletions.
26 changes: 20 additions & 6 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use std::str::Utf8Error;
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, Int64Array, LargeStringArray, StringArray, StringViewArray, UInt64Array,
};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::error::ArrowError;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};
Expand All @@ -20,12 +22,12 @@ pub fn return_type_check(args: &[DataType], fn_name: &str) -> DataFusionResult<(
let Some(first) = args.first() else {
return plan_err!("The '{fn_name}' function requires one or more arguments.");
};
if !(is_str(undict(first)) || is_json_union(first)) {
if !(is_str(unpack_dict_type(first)) || is_json_union(first)) {
// if !matches!(first, DataType::Utf8 | DataType::LargeUtf8) {
return plan_err!("Unexpected argument type to '{fn_name}' at position 1, expected a string, got {first:?}.");
}
args.iter().skip(1).enumerate().try_for_each(|(index, arg)| {
let t = undict(arg);
let t = unpack_dict_type(arg);
if is_str(t) || is_int(t) {
Ok(())
} else {
Expand All @@ -46,7 +48,16 @@ fn is_int(d: &DataType) -> bool {
matches!(d, DataType::UInt64 | DataType::Int64)
}

fn undict(d: &DataType) -> &DataType {
/// Convert a dict array to a non-dict array.
fn unpack_dict_array(array: ArrayRef) -> Result<ArrayRef, ArrowError> {
match array.data_type() {
DataType::Dictionary(_, value_type) => cast(array.as_ref(), value_type),
_ => Ok(array),
}
}

// if the type is a dict, return the value type, otherwise return the type
fn unpack_dict_type(d: &DataType) -> &DataType {
if let DataType::Dictionary(_, value) = d {
value.as_ref()
} else {
Expand Down Expand Up @@ -129,7 +140,8 @@ fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> DataFusionResult<ArrayRef> {
if let Some(d) = needle_array.as_any_dictionary_opt() {
invoke_array(json_array, d.values(), to_array, jiter_find)
let values = invoke_array(json_array, d.values(), to_array, jiter_find)?;
unpack_dict_array(d.with_values(values)).map_err(Into::into)
} else if let Some(str_path_array) = needle_array.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, to_array, jiter_find, true)
Expand Down Expand Up @@ -160,7 +172,8 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Optio
if let Some(d) = json_array.as_any_dictionary_opt() {
// NOTE we do NOT map back to an dictionary as that doesn't work for `is null` or filtering
// see https://github.com/apache/datafusion/issues/12380
return zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup);
let values = zip_apply(d.values(), path_array, to_array, jiter_find, object_lookup)?;
return unpack_dict_array(d.with_values(values)).map_err(Into::into);
}

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
Expand Down Expand Up @@ -226,7 +239,8 @@ fn scalar_apply<C: FromIterator<Option<I>>, I>(
) -> DataFusionResult<ArrayRef> {
if let Some(d) = json_array.as_any_dictionary_opt() {
// as above, don't return a dict
return scalar_apply(d.values(), path, to_array, jiter_find);
let values = scalar_apply(d.values(), path, to_array, jiter_find)?;
return unpack_dict_array(d.with_values(values)).map_err(Into::into);
}

let c = if let Some(string_array) = json_array.as_any().downcast_ref::<StringArray>() {
Expand Down
124 changes: 103 additions & 21 deletions tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, RecordBatch};
use datafusion::arrow::datatypes::{Field, Int8Type, Schema};
use datafusion::arrow::{array::StringDictionaryBuilder, datatypes::DataType};
use datafusion::assert_batches_eq;
use datafusion::common::ScalarValue;
use datafusion::logical_expr::ColumnarValue;

use datafusion::prelude::SessionContext;
use datafusion_functions_json::udfs::json_get_str_udf;
use utils::{display_val, logical_plan, run_query, run_query_large, run_query_params};
use utils::{create_context, display_val, logical_plan, run_query, run_query_large, run_query_params};

mod utils;

Expand Down Expand Up @@ -197,11 +201,11 @@ async fn test_json_get_no_path() {
let batches = run_query(r#"select json_get('"foo"')::string"#).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Utf8, "foo".to_string()));

let batches = run_query(r#"select json_get('123')::int"#).await.unwrap();
let batches = run_query(r"select json_get('123')::int").await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Int64, "123".to_string()));

let batches = run_query(r#"select json_get('true')::int"#).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Int64, "".to_string()));
let batches = run_query(r"select json_get('true')::int").await.unwrap();
assert_eq!(display_val(batches).await, (DataType::Int64, String::new()));
}

#[tokio::test]
Expand Down Expand Up @@ -350,7 +354,7 @@ async fn test_json_length_object() {
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::UInt64, "3".to_string()));

let sql = r#"select json_length('{}')"#;
let sql = r"select json_length('{}')";
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::UInt64, "0".to_string()));
}
Expand All @@ -359,7 +363,7 @@ async fn test_json_length_object() {
async fn test_json_length_string() {
let sql = r#"select json_length('"foobar"')"#;
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string()));
assert_eq!(display_val(batches).await, (DataType::UInt64, String::new()));
}

#[tokio::test]
Expand All @@ -370,7 +374,7 @@ async fn test_json_length_object_nested() {

let sql = r#"select json_length('{"a": 1, "b": 2, "c": []}', 'b')"#;
let batches = run_query(sql).await.unwrap();
assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string()));
assert_eq!(display_val(batches).await, (DataType::UInt64, String::new()));
}

#[tokio::test]
Expand Down Expand Up @@ -455,7 +459,7 @@ async fn test_json_contains_large_both_params() {

#[tokio::test]
async fn test_json_length_vec() {
let sql = r#"select name, json_len(json_data) as len from test"#;
let sql = r"select name, json_len(json_data) as len from test";
let batches = run_query(sql).await.unwrap();

let expected = [
Expand All @@ -479,7 +483,7 @@ async fn test_json_length_vec() {

#[tokio::test]
async fn test_no_args() {
let err = run_query(r#"select json_len()"#).await.unwrap_err();
let err = run_query(r"select json_len()").await.unwrap_err();
assert!(err
.to_string()
.contains("No function matches the given name and argument types 'json_length()'."));
Expand Down Expand Up @@ -562,10 +566,10 @@ async fn test_json_get_nested_collapsed() {
#[tokio::test]
async fn test_json_get_cte() {
// avoid auto-un-nesting with a CTE
let sql = r#"
let sql = r"
with t as (select name, json_get(json_data, 'foo') j from test)
select name, json_get(j, 0) v from t
"#;
";
let expected = [
"+------------------+---------+",
"| name | v |",
Expand All @@ -587,11 +591,11 @@ async fn test_json_get_cte() {
#[tokio::test]
async fn test_plan_json_get_cte() {
// avoid auto-unnesting with a CTE
let sql = r#"
let sql = r"
explain
with t as (select name, json_get(json_data, 'foo') j from test)
select name, json_get(j, 0) v from t
"#;
";
let expected = [
"Projection: t.name, json_get(t.j, Int64(0)) AS v",
" SubqueryAlias: t",
Expand Down Expand Up @@ -751,7 +755,7 @@ async fn test_arrow() {

#[tokio::test]
async fn test_plan_arrow() {
let lines = logical_plan(r#"explain select json_data->'foo' from test"#).await;
let lines = logical_plan(r"explain select json_data->'foo' from test").await;

let expected = [
"Projection: json_get(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")",
Expand Down Expand Up @@ -783,7 +787,7 @@ async fn test_long_arrow() {

#[tokio::test]
async fn test_plan_long_arrow() {
let lines = logical_plan(r#"explain select json_data->>'foo' from test"#).await;
let lines = logical_plan(r"explain select json_data->>'foo' from test").await;

let expected = [
"Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS test.json_data ->> Utf8(\"foo\")",
Expand Down Expand Up @@ -834,7 +838,7 @@ async fn test_arrow_cast_int() {

#[tokio::test]
async fn test_plan_arrow_cast_int() {
let lines = logical_plan(r#"explain select (json_data->'foo')::int from test"#).await;
let lines = logical_plan(r"explain select (json_data->'foo')::int from test").await;

let expected = [
"Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS test.json_data -> Utf8(\"foo\")",
Expand Down Expand Up @@ -866,7 +870,7 @@ async fn test_arrow_double_nested() {

#[tokio::test]
async fn test_plan_arrow_double_nested() {
let lines = logical_plan(r#"explain select json_data->'foo'->0 from test"#).await;
let lines = logical_plan(r"explain select json_data->'foo'->0 from test").await;

let expected = [
"Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)",
Expand Down Expand Up @@ -900,7 +904,7 @@ async fn test_arrow_double_nested_cast() {

#[tokio::test]
async fn test_plan_arrow_double_nested_cast() {
let lines = logical_plan(r#"explain select (json_data->'foo'->0)::int from test"#).await;
let lines = logical_plan(r"explain select (json_data->'foo'->0)::int from test").await;

let expected = [
"Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS test.json_data -> Utf8(\"foo\") -> Int64(0)",
Expand Down Expand Up @@ -948,7 +952,7 @@ async fn test_arrow_nested_double_columns() {
async fn test_lexical_precedence_wrong() {
let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#;
let err = run_query(sql).await.unwrap_err();
assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.")
assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.");
}

#[tokio::test]
Expand Down Expand Up @@ -1261,3 +1265,81 @@ async fn test_dict_get_int() {
let batches = run_query(sql).await.unwrap();
assert_batches_eq!(expected, &batches);
}

async fn build_dict_schema() -> SessionContext {
let mut builder = StringDictionaryBuilder::<Int8Type>::new();
builder.append(r#"{"foo": "bar"}"#).unwrap();
builder.append(r#"{"baz": "fizz"}"#).unwrap();
builder.append("nah").unwrap();
builder.append(r#"{"baz": "abcd"}"#).unwrap();
builder.append_null();
builder.append(r#"{"baz": "fizz"}"#).unwrap();
builder.append(r#"{"baz": "fizz"}"#).unwrap();
builder.append(r#"{"baz": "fizz"}"#).unwrap();
builder.append(r#"{"baz": "fizz"}"#).unwrap();
builder.append_null();

let dict = builder.finish();
let array = Arc::new(dict) as ArrayRef;

let schema = Arc::new(Schema::new(vec![Field::new(
"x",
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)),
true,
)]));

let data = RecordBatch::try_new(schema.clone(), vec![array]).unwrap();

let ctx = create_context().await.unwrap();
ctx.register_batch("data", data).unwrap();
ctx
}

#[tokio::test]
async fn test_dict_filter() {
let ctx = build_dict_schema().await;

let sql = "select json_get(x, 'baz') v from data";
let expected = [
"+------------+",
"| v |",
"+------------+",
"| {null=} |",
"| {str=fizz} |",
"| {null=} |",
"| {str=abcd} |",
"| {null=} |",
"| {str=fizz} |",
"| {str=fizz} |",
"| {str=fizz} |",
"| {str=fizz} |",
"| {null=} |",
"+------------+",
];

let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();

assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_dict_filter_is_not_null() {
let ctx = build_dict_schema().await;
let sql = "select x from data where json_get(x, 'baz') is not null";
let expected = [
"+-----------------+",
"| x |",
"+-----------------+",
"| {\"baz\": \"fizz\"} |",
"| {\"baz\": \"abcd\"} |",
"| {\"baz\": \"fizz\"} |",
"| {\"baz\": \"fizz\"} |",
"| {\"baz\": \"fizz\"} |",
"| {\"baz\": \"fizz\"} |",
"+-----------------+",
];

let batches = ctx.sql(sql).await.unwrap().collect().await.unwrap();

assert_batches_eq!(expected, &batches);
}
9 changes: 7 additions & 2 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ use datafusion::execution::context::SessionContext;
use datafusion::prelude::SessionConfig;
use datafusion_functions_json::register_all;

async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
pub async fn create_context() -> Result<SessionContext> {
let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres");
let mut ctx = SessionContext::new_with_config(config);
register_all(&mut ctx)?;
Ok(ctx)
}

async fn create_test_table(large_utf8: bool) -> Result<SessionContext> {
let ctx = create_context().await?;

let test_data = [
("object_foo", r#" {"foo": "abc"} "#),
Expand Down Expand Up @@ -214,5 +219,5 @@ pub async fn logical_plan(sql: &str) -> Vec<String> {
let batches = run_query(sql).await.unwrap();
let plan_col = batches[0].column(1).as_any().downcast_ref::<StringArray>().unwrap();
let logical_plan = plan_col.value(0);
logical_plan.split('\n').map(|s| s.to_string()).collect()
logical_plan.split('\n').map(ToString::to_string).collect()
}

0 comments on commit 3767061

Please sign in to comment.