Skip to content

Commit

Permalink
support lookup by column
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Apr 23, 2024
1 parent fa70af1 commit 23c9402
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 25 deletions.
50 changes: 40 additions & 10 deletions src/json_get_int.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::any::Any;
use std::sync::Arc;

use arrow::array::{as_string_array, Int64Array};
use arrow::array::{as_string_array, Int64Array, StringArray};
use arrow_schema::DataType;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::{exec_err, Result as DataFusionResult, ScalarValue};
Expand Down Expand Up @@ -51,18 +51,34 @@ impl ScalarUDFImpl for JsonGetInt {
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
let path = JsonPath::extract_args(args);

match &args[0] {
ColumnarValue::Array(array) => {
let array = as_string_array(array)
.iter()
.map(|opt_json| jiter_json_get_int(opt_json, &path).ok())
.collect::<Int64Array>();

Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
ColumnarValue::Array(json_array) => {
let result_array = match &args[1] {
ColumnarValue::Array(a) => {
if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(|s| JsonPath::Key(s)));
zip_apply(json_array, paths)
} else if let Some(int_path_array) = a.as_any().downcast_ref::<Int64Array>() {
let paths = int_path_array
.iter()
.map(|opt_index| opt_index.map(|i| JsonPath::Index(i as usize)));
zip_apply(json_array, paths)
} else {
return exec_err!("unexpected second argument type, expected string or int array");
}
}
_ => {
let path = JsonPath::extract_args(args);
as_string_array(json_array)
.iter()
.map(|opt_json| jiter_json_get_int(opt_json, &path).ok())
.collect::<Int64Array>()
}
};
Ok(ColumnarValue::from(Arc::new(result_array) as ArrayRef))
}
ColumnarValue::Scalar(ScalarValue::Utf8(s)) => {
let path = JsonPath::extract_args(args);
let v = jiter_json_get_int(s.as_ref().map(|s| s.as_str()), &path).ok();
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v)))
}
Expand All @@ -77,6 +93,20 @@ impl ScalarUDFImpl for JsonGetInt {
}
}

/// Pair-wise JSON lookup: the i-th JSON document in `json_array` is queried
/// with the i-th path yielded by `paths`. A `None` path (null key/index in the
/// source column) or a failed lookup produces a null in the output array.
fn zip_apply<'a, T: Iterator<Item = Option<JsonPath<'a>>>>(json_array: &ArrayRef, paths: T) -> Int64Array {
    let json_iter = as_string_array(json_array).iter();
    json_iter
        .zip(paths)
        // `and_then` short-circuits on a null path, mirroring the explicit
        // `if let`/`else None` branch this replaces.
        .map(|(opt_json, opt_path)| opt_path.and_then(|path| jiter_json_get_int(opt_json, &[path]).ok()))
        .collect()
}

fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result<i64, GetError> {
if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) {
match peek {
Expand Down
31 changes: 28 additions & 3 deletions tests/test_json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ async fn test_json_get_str() {

#[tokio::test]
async fn test_json_get_str_equals() {
let batches = run_query("select name, json_get_str(json_data, 'foo')='abc' from test")
.await
.unwrap();
let sql = "select name, json_get_str(json_data, 'foo')='abc' from test";
let batches = run_query(sql).await.unwrap();

let expected = [
"+------------------+--------------------------------------------------------+",
Expand Down Expand Up @@ -167,6 +166,32 @@ async fn test_json_get_cast_int_path() {
assert_eq!(display_val(batches).await, (DataType::Int64, "73".to_string()));
}

#[tokio::test]
async fn test_json_get_int_lookup() {
    // Path taken from a string column: only the row whose `str_key` exists in
    // its JSON object should survive the `is not null` filter.
    let query = "select str_key, json_data from other where json_get_int(json_data, str_key) is not null";
    let results = run_query(query).await.unwrap();
    let expected = [
        "+---------+---------------+",
        "| str_key | json_data |",
        "+---------+---------------+",
        "| foo | {\"foo\": 42} |",
        "+---------+---------------+",
    ];
    assert_batches_eq!(expected, &results);

    // Path taken from an int column: index 0 hits the single-element array.
    let query = "select int_key, json_data from other where json_get_int(json_data, int_key) is not null";
    let results = run_query(query).await.unwrap();
    let expected = [
        "+---------+-----------+",
        "| int_key | json_data |",
        "+---------+-----------+",
        "| 0 | [42] |",
        "+---------+-----------+",
    ];
    assert_batches_eq!(expected, &results);
}

#[tokio::test]
async fn test_json_get_float() {
let batches = run_query(r#"select json_get_float('[1.5]', 0) as v"#).await.unwrap();
Expand Down
49 changes: 37 additions & 12 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(dead_code)]
use arrow::array::Int64Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{array::StringArray, record_batch::RecordBatch};
Expand All @@ -9,12 +10,10 @@ use datafusion::execution::context::SessionContext;
use datafusion_functions_json::register_all;

async fn create_test_table() -> Result<SessionContext> {
let schema = Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("json_data", DataType::Utf8, false),
]));
let mut ctx = SessionContext::new();
register_all(&mut ctx)?;

let data = [
let test_data = [
("object_foo", r#" {"foo": "abc"} "#),
("object_foo_array", r#" {"foo": [1]} "#),
("object_foo_obj", r#" {"foo": {}} "#),
Expand All @@ -23,22 +22,48 @@ async fn create_test_table() -> Result<SessionContext> {
("list_foo", r#" ["foo"] "#),
("invalid_json", "is not json"),
];
let test_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("json_data", DataType::Utf8, false),
])),
vec![
Arc::new(StringArray::from(
test_data.iter().map(|(name, _)| *name).collect::<Vec<_>>(),
)),
Arc::new(StringArray::from(
test_data.iter().map(|(_, json)| *json).collect::<Vec<_>>(),
)),
],
)?;
ctx.register_batch("test", test_batch)?;

let batch = RecordBatch::try_new(
schema,
let other_data = [
(r#" {"foo": 42} "#, "foo", 0),
(r#" {"foo": 42} "#, "bar", 1),
(r#" [42] "#, "foo", 0),
(r#" [42] "#, "bar", 1),
];
let other_batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new("json_data", DataType::Utf8, false),
Field::new("str_key", DataType::Utf8, false),
Field::new("int_key", DataType::Int64, false),
])),
vec![
Arc::new(StringArray::from(
data.iter().map(|(name, _)| *name).collect::<Vec<_>>(),
other_data.iter().map(|(json, _, _)| *json).collect::<Vec<_>>(),
)),
Arc::new(StringArray::from(
data.iter().map(|(_, json)| *json).collect::<Vec<_>>(),
other_data.iter().map(|(_, str_key, _)| *str_key).collect::<Vec<_>>(),
)),
Arc::new(Int64Array::from(
other_data.iter().map(|(_, _, int_key)| *int_key).collect::<Vec<_>>(),
)),
],
)?;
ctx.register_batch("other", other_batch)?;

let mut ctx = SessionContext::new();
register_all(&mut ctx)?;
ctx.register_batch("test", batch)?;
Ok(ctx)
}

Expand Down

0 comments on commit 23c9402

Please sign in to comment.