Skip to content

Commit

Permalink
simplify further
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 27, 2024
1 parent 3fecc88 commit da247df
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 47 deletions.
61 changes: 20 additions & 41 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::str::Utf8Error;
use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayRef, AsArray, BooleanArray, DictionaryArray, Float64Array, Int64Array, LargeStringArray,
PrimitiveArray, StringArray, StringViewArray, UInt64Array, UnionArray,
Array, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, StringArray,
StringViewArray, UInt64Array, UnionArray,
};
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType};
Expand All @@ -12,9 +12,7 @@ use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarV
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, JsonUnion, JsonUnionField, TYPE_ID_NULL,
};
use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -103,29 +101,7 @@ impl<'s> JsonPath<'s> {
}
}

/// Same as `FromIterator` but we defined within the crate so we can custom as we wish,
/// e.g. for `ListArray` with `Vec<String>`.
pub(crate) trait JiterFromIterator<I>: Sized {
fn jiter_from_iter<T: IntoIterator<Item = I>>(iter: T) -> Self;
}

macro_rules! impl_jiter_from_iterator {
($collect:ty, $item:ty) => {
impl JiterFromIterator<$item> for $collect {
fn jiter_from_iter<T: IntoIterator<Item = $item>>(iter: T) -> Self {
<$collect>::from_iter(iter)
}
}
};
}
impl_jiter_from_iterator!(Int64Array, Option<i64>);
impl_jiter_from_iterator!(UInt64Array, Option<u64>);
impl_jiter_from_iterator!(Float64Array, Option<f64>);
impl_jiter_from_iterator!(StringArray, Option<String>);
impl_jiter_from_iterator!(BooleanArray, Option<bool>);
impl_jiter_from_iterator!(JsonUnion, Option<JsonUnionField>);

pub(crate) fn invoke<C: JiterFromIterator<Option<I>> + 'static, I>(
pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
args: &[ColumnarValue],
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
Expand Down Expand Up @@ -162,7 +138,7 @@ pub(crate) fn invoke<C: JiterFromIterator<Option<I>> + 'static, I>(
}
}

fn invoke_array<C: JiterFromIterator<Option<I>> + 'static, I>(
fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
json_array: &ArrayRef,
needle_array: &ArrayRef,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
Expand Down Expand Up @@ -192,7 +168,7 @@ fn invoke_array<C: JiterFromIterator<Option<I>> + 'static, I>(
}
}

fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: JiterFromIterator<Option<I>> + 'static, I>(
fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
json_array: &ArrayRef,
path_array: P,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
Expand Down Expand Up @@ -221,18 +197,21 @@ fn zip_apply<'a, P: Iterator<Item = Option<JsonPath<'a>>>, C: JiterFromIterator<
to_array(c)
}

fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: JiterFromIterator<Option<I>> + 'static, I>(
fn zip_apply_iter<'a, 'j, P: Iterator<Item = Option<JsonPath<'a>>>, C: FromIterator<Option<I>> + 'static, I>(
json_iter: impl Iterator<Item = Option<&'j str>>,
path_array: P,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> C {
C::jiter_from_iter(json_iter.zip(path_array).map(|(opt_json, opt_path)| {
if let Some(path) = opt_path {
jiter_find(opt_json, &[path]).ok()
} else {
None
}
}))
json_iter
.zip(path_array)
.map(|(opt_json, opt_path)| {
if let Some(path) = opt_path {
jiter_find(opt_json, &[path]).ok()
} else {
None
}
})
.collect::<C>()
}

fn invoke_scalar<I>(
Expand All @@ -259,7 +238,7 @@ fn invoke_scalar<I>(
}
}

fn scalar_apply<C: JiterFromIterator<Option<I>>, I>(
fn scalar_apply<C: FromIterator<Option<I>>, I>(
json_array: &ArrayRef,
path: &[JsonPath],
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
Expand Down Expand Up @@ -317,12 +296,12 @@ fn is_object_lookup(path: &[JsonPath]) -> bool {
}
}

fn scalar_apply_iter<'j, C: JiterFromIterator<Option<I>>, I>(
fn scalar_apply_iter<'j, C: FromIterator<Option<I>>, I>(
json_iter: impl Iterator<Item = Option<&'j str>>,
path: &[JsonPath],
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
) -> C {
C::jiter_from_iter(json_iter.map(|opt_json| jiter_find(opt_json, path).ok()))
json_iter.map(|opt_json| jiter_find(opt_json, path).ok()).collect::<C>()
}

pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {
Expand Down
15 changes: 9 additions & 6 deletions src/json_object_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use datafusion::common::{Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use jiter::Peek;

use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JiterFromIterator, JsonPath};
use crate::common::{get_err, invoke, jiter_json_find, return_type_check, GetError, JsonPath};
use crate::common_macros::make_udf_function;

make_udf_function!(
Expand Down Expand Up @@ -54,10 +54,10 @@ impl ScalarUDFImpl for JsonObjectKeys {
}

fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult<ColumnarValue> {
invoke::<ListArray, Vec<String>>(
invoke::<ListArrayWrapper, Vec<String>>(
args,
jiter_json_object_keys,
|c| Ok(Arc::new(c) as ArrayRef),
|w| Ok(Arc::new(w.0) as ArrayRef),
keys_to_scalar,
true,
)
Expand All @@ -68,8 +68,11 @@ impl ScalarUDFImpl for JsonObjectKeys {
}
}

impl JiterFromIterator<Option<Vec<String>>> for ListArray {
fn jiter_from_iter<I: IntoIterator<Item = Option<Vec<String>>>>(iter: I) -> Self {
#[derive(Debug)]
struct ListArrayWrapper(ListArray);

impl FromIterator<Option<Vec<String>>> for ListArrayWrapper {
fn from_iter<I: IntoIterator<Item = Option<Vec<String>>>>(iter: I) -> Self {
let values_builder = StringBuilder::new();
let mut builder = ListBuilder::new(values_builder);
for opt_keys in iter {
Expand All @@ -82,7 +85,7 @@ impl JiterFromIterator<Option<Vec<String>>> for ListArray {
builder.append(false);
}
}
builder.finish()
Self(builder.finish())
}
}

Expand Down

0 comments on commit da247df

Please sign in to comment.