Skip to content

Commit

Permalink
cleanup and improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Apr 23, 2024
1 parent 8a31e63 commit 279460e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 42 deletions.
41 changes: 28 additions & 13 deletions src/common_get.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
use arrow_schema::DataType;
use std::str::Utf8Error;

use datafusion_common::{exec_err, Result as DatafusionResult, ScalarValue};
use datafusion_common::{plan_err, Result as DataFusionResult, ScalarValue};
use datafusion_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

/// Validate the argument types supplied to a JSON getter function.
///
/// The first argument is not type-checked here; every argument after it
/// must be `Utf8`, `UInt64` or `Int64`.
///
/// # Errors
///
/// Returns a planning error when fewer than two arguments are given, or when
/// any argument after the first has an unsupported type. The reported
/// position is 1-based and counts the first argument as position 1.
pub fn check_args(args: &[DataType], fn_name: &str) -> DataFusionResult<()> {
    if args.len() < 2 {
        return plan_err!("The `{fn_name}` function requires two or more arguments.");
    }

    // Skip the first argument; `offset` is the 0-based index into `args`,
    // so `offset + 1` is the 1-based position reported to the user.
    for (offset, data_type) in args.iter().enumerate().skip(1) {
        if !matches!(data_type, DataType::Utf8 | DataType::UInt64 | DataType::Int64) {
            return plan_err!(
                "Unexpected argument type to `{fn_name}` at position {}, expected string or int.",
                offset + 1
            );
        }
    }

    Ok(())
}

/// A single step in a path used to look up a value inside a JSON document.
///
/// Values are built from the function arguments that follow the JSON input
/// (see `JsonPath::extract_args`).
#[derive(Debug)]
pub enum JsonPath<'s> {
/// Look up a member by key; only matches when the current JSON value is an object.
Key(&'s str),
/// Look up an element by position; only matches when the current JSON value is an array.
Index(usize),
/// Placeholder for an argument that is not a non-null scalar string or integer;
/// it never matches, so the lookup fails with `GetError`.
None,
}

impl<'s> JsonPath<'s> {
pub fn extract_args(args: &'s [ColumnarValue], fn_name: &str) -> DatafusionResult<Vec<Self>> {
pub fn extract_args(args: &'s [ColumnarValue]) -> Vec<Self> {
args[1..]
.iter()
.enumerate()
.map(|(index, arg)| match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(Self::Key(s)),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Ok(Self::Index(*i as usize)),
ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => Ok(Self::Index(*i as usize)),
_ => exec_err!(
"`{fn_name}`: unexpected argument type at {}, expected string or int",
index + 2
),
.map(|arg| match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Self::Key(s),
ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Self::Index(*i as usize),
ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => Self::Index(*i as usize),
_ => Self::None,
})
.collect()
}
Expand All @@ -45,11 +60,11 @@ fn jiter_json_find_step(jiter: &mut Jiter, peek: Peek, path: &[JsonPath]) -> Res
let next_peek = match peek {
Peek::Array => match first {
JsonPath::Index(index) => jiter_array_get(jiter, *index),
JsonPath::Key(_) => Err(GetError),
_ => Err(GetError),
},
Peek::Object => match first {
JsonPath::Key(key) => jiter_object_get(jiter, key),
JsonPath::Index(_) => Err(GetError),
_ => Err(GetError),
},
_ => Err(GetError),
}?;
Expand Down
20 changes: 6 additions & 14 deletions src/json_get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::sync::Arc;
use arrow::array::{as_string_array, Array, UnionArray};
use arrow_schema::DataType;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::{exec_err, plan_err, Result as DatafusionResult, ScalarValue};
use datafusion_common::{exec_err, Result as DataFusionResult, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::{Jiter, NumberAny, NumberInt, Peek};

use crate::common_get::{jiter_json_find, GetError, JsonPath};
use crate::common_get::{check_args, jiter_json_find, GetError, JsonPath};
use crate::common_macros::make_udf_function;
use crate::common_union::{JsonUnion, JsonUnionField};

Expand Down Expand Up @@ -47,20 +47,12 @@ impl ScalarUDFImpl for JsonGet {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DatafusionResult<DataType> {
if arg_types.len() < 2 {
return plan_err!("The `json_get` function requires two or more arguments.");
}
match arg_types[0] {
DataType::Utf8 | DataType::UInt64 => Ok(JsonUnion::data_type()),
_ => {
plan_err!("The `json_get_str` function can only accepts Utf8 or UInt64 arguments.")
}
}
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|_| JsonUnion::data_type())
}

fn invoke(&self, args: &[ColumnarValue]) -> DatafusionResult<ColumnarValue> {
let path = JsonPath::extract_args(args, self.name())?;
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
let path = JsonPath::extract_args(args);

match &args[0] {
ColumnarValue::Array(array) => {
Expand Down
18 changes: 6 additions & 12 deletions src/json_get_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ use std::sync::Arc;
use arrow::array::{as_string_array, StringArray};
use arrow_schema::DataType;
use datafusion_common::arrow::array::ArrayRef;
use datafusion_common::{exec_err, plan_err, Result as DatafusionResult, ScalarValue};
use datafusion_common::{exec_err, Result as DataFusionResult, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common_get::{jiter_json_find, GetError, JsonPath};
use crate::common_get::{check_args, jiter_json_find, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -46,18 +46,12 @@ impl ScalarUDFImpl for JsonGet {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DatafusionResult<DataType> {
if arg_types.len() < 2 {
return plan_err!("The `json_get_str` function requires two or more arguments.");
}
match arg_types[0] {
DataType::Utf8 | DataType::UInt64 | DataType::Int64 => Ok(DataType::Utf8),
_ => plan_err!("The `json_get_str` function can only accepts string or int arguments."),
}
fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
check_args(arg_types, self.name()).map(|_| DataType::Utf8)
}

fn invoke(&self, args: &[ColumnarValue]) -> DatafusionResult<ColumnarValue> {
let path = JsonPath::extract_args(args, self.name())?;
fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
let path = JsonPath::extract_args(args);

match &args[0] {
ColumnarValue::Array(array) => {
Expand Down
29 changes: 26 additions & 3 deletions tests/test_json_get.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use arrow_schema::DataType;
use datafusion::assert_batches_eq;

mod utils;
use utils::run_query;
use utils::{display_val, run_query};

#[tokio::test]
async fn test_json_get_union() {
Expand Down Expand Up @@ -119,6 +120,28 @@ async fn test_json_get_str_path() {
.await
.unwrap();

let expected = ["+---+", "| v |", "+---+", "| x |", "+---+"];
assert_batches_eq!(expected, &batches);
assert_eq!(
display_val(batches).await,
("v".to_string(), DataType::Utf8, "x".to_string())
);
}

/// A `null` path argument to `json_get_str` must be rejected with a planning error.
#[tokio::test]
async fn test_json_get_str_null() {
    let sql = r#"select json_get_str('{}', null)"#;
    let err = run_query(sql).await.unwrap_err();

    let expected =
        "Error during planning: Unexpected argument type to `json_get_str` at position 2, expected string or int.";
    assert_eq!(err.to_string(), expected);
}

/// Calling `json_get` with only the JSON input and no path must fail at planning time.
#[tokio::test]
async fn test_json_get_one_arg() {
    let sql = r#"select json_get('{}')"#;
    let err = run_query(sql).await.unwrap_err();

    let expected = "Error during planning: The `json_get` function requires two or more arguments.";
    assert_eq!(err.to_string(), expected);
}
15 changes: 15 additions & 0 deletions tests/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#![allow(dead_code)]
use arrow::datatypes::{DataType, Field, Schema};
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{array::StringArray, record_batch::RecordBatch};
use std::sync::Arc;

Expand Down Expand Up @@ -45,3 +47,16 @@ pub async fn run_query(sql: &str) -> Result<Vec<RecordBatch>> {
let df = ctx.sql(sql).await?;
df.collect().await
}

/// Reduce a single-batch, single-row query result to a comparable triple.
///
/// Returns `(column name, column data type, formatted value)` for the first
/// column of the first row.
///
/// # Panics
///
/// Panics if `batch` does not contain exactly one record batch, if that
/// batch does not contain exactly one row, or if the value cannot be
/// formatted.
pub async fn display_val(batch: Vec<RecordBatch>) -> (String, DataType, String) {
    assert_eq!(batch.len(), 1);
    // Safe to index: the assertion above guarantees exactly one element.
    let only_batch = &batch[0];
    assert_eq!(only_batch.num_rows(), 1);

    let schema = only_batch.schema();
    let field = schema.field(0);

    let format_options = FormatOptions::default().with_display_error(true);
    let formatter = ArrayFormatter::try_new(only_batch.column(0).as_ref(), &format_options).unwrap();
    let rendered = formatter.value(0).try_to_string().unwrap();

    (field.name().clone(), field.data_type().clone(), rendered)
}

0 comments on commit 279460e

Please sign in to comment.