feat: support 'col IN (a, b, c)' type expressions #652

Open · wants to merge 10 commits into base: main
2 changes: 1 addition & 1 deletion kernel/src/engine/arrow_conversion.rs
@@ -208,7 +208,7 @@ impl TryFrom<&ArrowDataType> for DataType {
ArrowDataType::Date64 => Ok(DataType::DATE),
ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ),
ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz))
if tz.eq_ignore_ascii_case("utc") =>
if tz.eq_ignore_ascii_case("utc") || tz.eq_ignore_ascii_case("+00:00") =>
Collaborator Author:
The data in arrow arrays should always represent a timestamp in UTC, so is this check even necessary?

https://github.com/apache/arrow-rs/blob/af777cd53e56f8382382137b6e08af249c475397/arrow-schema/src/datatype.rs#L179-L182
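For illustration, a minimal sketch (assuming arrow-rs's `TimestampMicrosecondArray`) of why the check may be redundant: the stored i64 values are UTC epoch microseconds regardless of the timezone string, which only affects rendering.

use arrow_array::TimestampMicrosecondArray;

let utc = TimestampMicrosecondArray::from(vec![1_000_000]).with_timezone("UTC");
let offset = TimestampMicrosecondArray::from(vec![1_000_000]).with_timezone("+00:00");
// Same underlying instant; only the display timezone differs.
assert_eq!(utc.value(0), offset.value(0));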

{
Ok(DataType::TIMESTAMP)
}
158 changes: 156 additions & 2 deletions kernel/src/engine/arrow_expression.rs
@@ -280,6 +280,81 @@ fn evaluate_expression(
(ArrowDataType::Decimal256(_, _), Decimal256Type)
}
}
(Column(name), Literal(Scalar::Array(ad))) => {
use crate::expressions::ArrayData;

let column = extract_column(batch, name)?;
let data_type = ad
.array_type()
.element_type()
.as_primitive_opt()
.ok_or_else(|| {
Error::invalid_expression(format!(
"IN only supports array literals with primitive elements, got: '{:?}'",
ad.array_type().element_type()
))
})?;

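// Helper: a NULL column value stays NULL; otherwise check membership of the
// value in the literal in-list via `contains`.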
fn op(
col: impl Iterator<Item = Option<impl Into<Scalar>>>,
ad: &ArrayData,
) -> BooleanArray {
#[allow(deprecated)]
let res = col.map(|val| val.map(|v| ad.array_elements().contains(&v.into())));
Collaborator:
I don't think this handles NULL values correctly? See e.g. https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-:

  • TRUE is returned when the non-NULL value in question is found in the list
  • FALSE is returned when the non-NULL value is not found in the list and the list does not contain NULL values
  • UNKNOWN is returned when the value is NULL, or the non-NULL value is not found in the list and the list contains at least one NULL value

I think, instead of calling contains, you could borrow the code from PredicateEvaluatorDefaults::finish_eval_variadic, with true as the "dominator" value.

Collaborator:
Actually, I think you could just invoke that method directly, with a properly crafted iterator?

// `v IN (k1, ..., kN)` is logically equivalent to `v = k1 OR ... OR v = kN`, so evaluate
// it as such, ensuring correct handling of NULL inputs (including `Scalar::Null`).
col.map(|v| {
    PredicateEvaluatorDefaults::finish_eval_variadic(
        VariadicOperator::Or, 
        inlist.iter().map(|k| Some(Scalar::partial_cmp(v.as_ref()?, k)? == Ordering::Equal)),
        false,
    )
})

Collaborator Author:
Was I correct in thinking that None (no dominant value, but a Null was found) should just be false in this case?

Collaborator:
Aside: We actually have a lurking bug -- Scalar derives PartialEq which will allow two Scalar::Null to compare equal. But SQL semantics dictate that NULL doesn't compare equal to anything -- not even itself.

Our manual impl of PartialOrd for Scalar does this correctly, but it breaks the rules for PartialEq:

If PartialOrd or Ord are also implemented for Self and Rhs, their methods must also be consistent with PartialEq (see the documentation of those traits for the exact requirements). It’s easy to accidentally make them disagree by deriving some of the traits and manually implementing others.

Looks like we'll need to define a manual impl PartialEq for Scalar that follows the same approach.

Collaborator Author:
This is indeed not covered. Added an implementation for PartialEq that mirrors PartialOrd.
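A minimal sketch of that idea (assuming the crate's `Scalar` and its existing manual `PartialOrd`; the actual implementation in the PR may differ): define equality through the ordering, so `Null` never compares equal, not even to itself.

use std::cmp::Ordering;

impl PartialEq for Scalar {
    fn eq(&self, other: &Self) -> bool {
        // partial_cmp returns None whenever either side is Null, so equality
        // only holds when a real comparison yields Equal.
        self.partial_cmp(other) == Some(Ordering::Equal)
    }
}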

BooleanArray::from_iter(res)
}

// safety: as_* methods on arrow arrays can panic, but we checked the data type before applying.
let arr: BooleanArray = match (column.data_type(), data_type) {
(ArrowDataType::Utf8, PrimitiveType::String) => op(column.as_string::<i32>().iter(), ad),
(ArrowDataType::LargeUtf8, PrimitiveType::String) => op(column.as_string::<i64>().iter(), ad),
(ArrowDataType::Utf8View, PrimitiveType::String) => op(column.as_string_view().iter(), ad),
(ArrowDataType::Int8, PrimitiveType::Byte) => op(column.as_primitive::<Int8Type>().iter(), ad),
(ArrowDataType::Int16, PrimitiveType::Short) => op(column.as_primitive::<Int16Type>().iter(), ad),
(ArrowDataType::Int32, PrimitiveType::Integer) => op(column.as_primitive::<Int32Type>().iter(), ad),
(ArrowDataType::Int64, PrimitiveType::Long) => op(column.as_primitive::<Int64Type>().iter(), ad),
(ArrowDataType::Float32, PrimitiveType::Float) => op(column.as_primitive::<Float32Type>().iter(), ad),
(ArrowDataType::Float64, PrimitiveType::Double) => op(column.as_primitive::<Float64Type>().iter(), ad),
(ArrowDataType::Date32, PrimitiveType::Date) => {
#[allow(deprecated)]
let res = column
.as_primitive::<Date32Type>()
.iter()
.map(|val| val.map(|v| ad.array_elements().contains(&Scalar::Date(v))));
BooleanArray::from_iter(res)
}
(
ArrowDataType::Timestamp(TimeUnit::Microsecond, unit),
kt @ PrimitiveType::Timestamp | kt @ PrimitiveType::TimestampNtz,
) => {
let res = column.as_primitive::<TimestampMicrosecondType>().iter();
match (unit, kt) {
// regardless of the time zone stored in the timestamp, the underlying value is always in UTC
(Some(_), PrimitiveType::Timestamp) => {
BooleanArray::from_iter(res.map(|val| {
#[allow(deprecated)]
val.map(|v| ad.array_elements().contains(&Scalar::Timestamp(v)))
}))
}
(None, PrimitiveType::TimestampNtz) => {
BooleanArray::from_iter(res.map(|val| {
val.map(|v| {
#[allow(deprecated)]
ad.array_elements().contains(&Scalar::TimestampNtz(v))
})
}))
}
_ => unreachable!(),
}
}
(l, r) => {
return Err(Error::invalid_expression(format!(
"Cannot check if value of type '{l}' is contained in array with values of type '{r}'"
)))
}
};
Ok(Arc::new(arr))
}
(Literal(lit), Literal(Scalar::Array(ad))) => {
#[allow(deprecated)]
let exists = ad.array_elements().contains(lit);
@@ -382,8 +457,8 @@ fn new_field_with_metadata(

// A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct
// and use that method to transform each column and then put the struct back together. Target types
-// and names for each column should be passed in `target_types_and_names`. The number of elements in
-// the `target_types_and_names` iterator _must_ be the same as the number of columns in
+// and names for each column should be passed in `target_fields`. The number of elements in
+// the `target_fields` iterator _must_ be the same as the number of columns in
// `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields`
// _must_ match the order of the columns in `struct_array`.
fn transform_struct(
@@ -692,6 +767,85 @@ mod tests {
assert_eq!(in_result.as_ref(), &in_expected);
}

#[test]
fn test_column_in_array() {
Collaborator (@scovich, Jan 30, 2025):
Relating to #652 (comment) and #652 (comment) -- we don't seem to have any tests that cover NULL value semantics?

1 IN (1, NULL) -- TRUE
1 IN (2, NULL) -- NULL
NULL IN (1, 2) -- NULL

1 NOT IN (1, NULL) -- FALSE (**)
1 NOT IN (2, NULL) -- NULL
NULL NOT IN (1, 2) -- NULL

(**) NOTE from https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-:

NOT IN always returns UNKNOWN when the list contains NULL, regardless of the input value. This is because IN returns UNKNOWN if the value is not in the list containing NULL, and because NOT UNKNOWN is again UNKNOWN.

IMO, that explanation is confusing and factually incorrect. If we explain it in terms of NOT(OR):

1 NOT IN (1, NULL)
= NOT(1 IN (1, NULL))
= NOT(1 = 1 OR 1 = NULL)
= NOT(1 = 1) AND NOT(1 = NULL)
= 1 != 1 AND 1 != NULL
= FALSE AND NULL
= FALSE

As additional support for my claim: sqlite, postgres, and mysql all return FALSE (not NULL) for that expression.
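To make those semantics concrete, here is a minimal sketch of the three-valued OR underlying IN (a hypothetical stand-in for `finish_eval_variadic` with `VariadicOperator::Or`; `None` models SQL NULL):

// Hypothetical helper illustrating SQL three-valued OR; `None` is NULL/UNKNOWN.
fn or_kleene(vals: impl IntoIterator<Item = Option<bool>>) -> Option<bool> {
    let mut saw_null = false;
    for v in vals {
        match v {
            Some(true) => return Some(true), // TRUE dominates OR
            Some(false) => {}
            None => saw_null = true,
        }
    }
    if saw_null { None } else { Some(false) }
}

// 1 IN (1, NULL) -> TRUE; 1 IN (2, NULL) -> NULL; NULL IN (1, 2) -> NULL
assert_eq!(or_kleene([Some(true), None]), Some(true));
assert_eq!(or_kleene([Some(false), None]), None);
assert_eq!(or_kleene([None, None]), None);
// 1 NOT IN (1, NULL) = NOT(TRUE) -> FALSE
assert_eq!(or_kleene([Some(true), None]).map(|b| !b), Some(false));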

Collaborator Author:

Added some tests according to the cases mentioned above, hopefully covering all of them. This uncovered some cases where we were not handling NULLs correctly in the other in-list branches, mainly because the arrow kernels don't seem to adhere to SQL NULL semantics.

In addition to the engines above, I also tried duckdb and datafusion, which also support @scovich's claim.

Collaborator Author:

Also, is this something worth upstreaming to arrow-rs, similar to the *_kleene variants for other kernels?

let values = Int32Array::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new("item", DataType::Int32, true));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Integer.into(), false),
[Scalar::Integer(1), Scalar::Integer(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
let in_expected = BooleanArray::from(vec![false, true, false, true]);
assert_eq!(in_result.as_ref(), &in_expected);

let not_in_op = Expression::binary(BinaryOperator::NotIn, column_expr!("item"), rhs);
let not_in_result =
evaluate_expression(&not_in_op, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
let not_in_expected = BooleanArray::from(vec![true, false, true, false]);
assert_eq!(not_in_result.as_ref(), &not_in_expected);

let in_expected = BooleanArray::from(vec![false, true, false, true]);

// Date arrays
let values = Date32Array::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new("item", DataType::Date32, true));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Date.into(), false),
[Scalar::Date(1), Scalar::Date(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);

// Timestamp arrays
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone("UTC");
let field = Arc::new(Field::new(
"item",
(&crate::schema::DataType::TIMESTAMP).try_into().unwrap(),
true,
));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Timestamp.into(), false),
[Scalar::Timestamp(1), Scalar::Timestamp(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);

// Timestamp NTZ arrays
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new(
"item",
(&crate::schema::DataType::TIMESTAMP_NTZ)
.try_into()
.unwrap(),
true,
));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::TimestampNtz.into(), false),
[Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);
}

#[test]
fn test_extract_column() {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);