feat: support 'col IN (a, b, c)' type expressions #652

Open · wants to merge 10 commits into base: main
2 changes: 1 addition & 1 deletion kernel/src/engine/arrow_conversion.rs
@@ -208,7 +208,7 @@ impl TryFrom<&ArrowDataType> for DataType {
ArrowDataType::Date64 => Ok(DataType::DATE),
ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ),
ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz))
-if tz.eq_ignore_ascii_case("utc") =>
+if tz.eq_ignore_ascii_case("utc") || tz.eq_ignore_ascii_case("+00:00") =>
Collaborator (Author) commented:

The data in arrow arrays should always represent a timestamp in UTC, so is this check even necessary?

https://github.com/apache/arrow-rs/blob/af777cd53e56f8382382137b6e08af249c475397/arrow-schema/src/datatype.rs#L179-L182

{
Ok(DataType::TIMESTAMP)
}
162 changes: 160 additions & 2 deletions kernel/src/engine/arrow_expression.rs
@@ -32,6 +32,7 @@ use crate::expressions::{
BinaryExpression, BinaryOperator, Expression, Scalar, UnaryExpression, UnaryOperator,
VariadicExpression, VariadicOperator,
};
use crate::predicates::PredicateEvaluatorDefaults;
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};

@@ -280,6 +281,84 @@ fn evaluate_expression(
(ArrowDataType::Decimal256(_, _), Decimal256Type)
}
}
(Column(name), Literal(Scalar::Array(ad))) => {
fn op<T: ArrowPrimitiveType>(
values: &dyn Array,
Collaborator commented:

Suggested change:
-values: &dyn Array,
+values: ArrayRef,

(avoids the need for .as_ref() at the call site)

Collaborator (Author) replied:

I think the main thing was that we need this to be a reference, otherwise the compiler starts complaining about lifetimes. I did shorten the code at the call site a bit; hope that works as well.

Collaborator @scovich commented (Jan 30, 2025):

Why would an ArrayRef (= Arc<dyn Array>) give lifetime problems, sorry?
We can always call as_ref() on it to get a reference that lives at least as long as the arc?

Collaborator (Author) replied:

Not the parameter itself, but the as_primitive cast inside the function returns a ref, which we then iterate over. This then causes issues with the iterator referencing data owned by the function.

from: fn(T::Native) -> Scalar,
) -> impl Iterator<Item = Option<Scalar>> + '_ {
values.as_primitive::<T>().iter().map(move |v| v.map(from))
}

fn str_op<'a>(
column: impl Iterator<Item = Option<&'a str>> + 'a,
) -> impl Iterator<Item = Option<Scalar>> + 'a {
column.map(|v| v.map(Scalar::from))
}

fn op_in(
inlist: &[Scalar],
values: impl Iterator<Item = Option<Scalar>>,
) -> BooleanArray {
// `v IN (k1, ..., kN)` is logically equivalent to `v = k1 OR ... OR v = kN`, so evaluate
// it as such, ensuring correct handling of NULL inputs (including `Scalar::Null`).
values
.map(|v| {
Some(
PredicateEvaluatorDefaults::finish_eval_variadic(
VariadicOperator::Or,
inlist.iter().map(|k| v.as_ref().map(|vv| vv == k)),
Collaborator @scovich commented (Jan 25, 2025):

This isn't correct -- we need comparisons against Scalar::Null to return None. That's why I had previously recommended using Scalar::partial_cmp instead of ==.

Also, can we not use ? to unwrap the various options here?

Suggested change:
-inlist.iter().map(|k| v.as_ref().map(|vv| vv == k)),
+inlist.iter().map(|k| Some(v.as_ref()?.partial_cmp(k)? == Ordering::Equal)),

Unpacking that -- if the value we search for is NULL, or if the inlist entry is NULL, or if the two values are incomparable, then return None for that pair. Otherwise, return Some boolean indicating whether the values compared equal or not. That automatically covers the various required cases, and also makes us robust to any type mismatches that might creep in.
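That tri-valued comparison can be sketched in isolation. `NaiveScalar` below is a hypothetical stand-in for the kernel's `Scalar`, not the actual type:

```rust
use std::cmp::Ordering;

// Hypothetical scalar with a NULL variant; comparing against NULL or across
// mismatched types yields None (SQL UNKNOWN).
#[derive(Debug)]
enum NaiveScalar {
    Int(i64),
    Text(String),
    Null,
}

// Tri-valued equality: None if either side is NULL or the pair is
// incomparable, otherwise Some(bool) with the definite comparison result.
fn try_eq(a: &NaiveScalar, b: &NaiveScalar) -> Option<bool> {
    use NaiveScalar::*;
    let ord = match (a, b) {
        (Null, _) | (_, Null) => None,
        (Int(x), Int(y)) => x.partial_cmp(y),
        (Text(x), Text(y)) => x.partial_cmp(y),
        _ => None, // type mismatch: incomparable
    };
    Some(ord? == Ordering::Equal)
}

fn main() {
    assert_eq!(try_eq(&NaiveScalar::Int(1), &NaiveScalar::Int(1)), Some(true));
    assert_eq!(try_eq(&NaiveScalar::Null, &NaiveScalar::Int(1)), None);
    assert_eq!(try_eq(&NaiveScalar::Int(1), &NaiveScalar::Text("a".into())), None);
}
```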

Collaborator commented:

Note: If we wanted to be a tad more efficient, we could also unpack v outside the inner loop:

values.map(|v| {
    let v = v?;
    PredicateEvaluatorDefaults::finish_eval_variadic(...)
})

Collaborator commented:

Hmm -- empty in-lists pose a corner case with respect to unpacking v:

NULL IN ()

Operator OR with zero inputs normally produces FALSE (which is correct if you stop to think about it) -- but unpacking a NULL v first makes the operator return NULL instead (which is also correct if you squint, because NULL input always produces NULL output).

Unfortunately, the only clear docs I could find -- https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery- -- are also ambiguous:

Conceptually a IN expression is semantically equivalent to a set of equality condition separated by a disjunctive operator (OR).

... suggests FALSE while

UNKNOWN is returned when the value is NULL

... suggests NULL

The difference matters for NOT IN, because NULL NOT IN () would either return TRUE (keep rows) or NULL (drop row).

Collaborator @scovich commented (Jan 25, 2025):

NOTE: SQL engines normally forbid statically empty in-list but do not forbid subqueries from producing an empty result.

I tried the following expression on three engines (sqlite, mysql, postgres):

SELECT 1 WHERE NULL NOT IN (SELECT 1 WHERE FALSE)

And all three returned 1. So OR semantics prevail, and we must NOT unpack v outside the loop.
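The prevailing OR semantics can be sketched standalone (illustrative helpers only, not kernel code; `None` stands in for SQL NULL):

```rust
// Kleene three-valued OR over an iterator of Option<bool>:
// TRUE dominates; otherwise NULL if any input was NULL; FALSE for zero inputs.
fn kleene_or(inputs: impl Iterator<Item = Option<bool>>) -> Option<bool> {
    let mut saw_null = false;
    for v in inputs {
        match v {
            Some(true) => return Some(true),
            Some(false) => {}
            None => saw_null = true,
        }
    }
    if saw_null { None } else { Some(false) }
}

// `v IN (k1, ..., kN)` as a disjunction of tri-valued equality checks.
// Note that `v` is NOT unpacked outside the loop, so the empty in-list
// still yields FALSE even for a NULL probe value.
fn in_list(v: Option<i32>, inlist: &[Option<i32>]) -> Option<bool> {
    kleene_or(inlist.iter().map(|k| Some(v? == (*k)?)))
}

fn main() {
    assert_eq!(in_list(Some(1), &[Some(1), None]), Some(true)); // 1 IN (1, NULL)
    assert_eq!(in_list(Some(1), &[Some(2), None]), None);       // 1 IN (2, NULL)
    assert_eq!(in_list(None, &[Some(1), Some(2)]), None);       // NULL IN (1, 2)
    assert_eq!(in_list(None, &[]), Some(false));                // NULL IN () -> FALSE
}
```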

Collaborator (Author) replied:

Hoping I now considered all your comments, which essentially means going with your original version.

false,
)
// None is returned when no dominant value (true) is found and there is at least one NULL.
// In the case of IN, this is equivalent to false.
.unwrap_or(false),
)
})
.collect()
}

#[allow(deprecated)]
let inlist = ad.array_elements();
let column = extract_column(batch, name)?;
let data_type = ad
.array_type()
.element_type()
.as_primitive_opt()
.ok_or_else(|| {
Error::invalid_expression(format!(
"IN only supports array literals with primitive elements, got: '{:?}'",
ad.array_type().element_type()
))
})?;

// safety: as_* methods on arrow arrays can panic, but we checked the data type before applying.
let arr = match (column.data_type(), data_type) {
(ArrowDataType::Utf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i32>().iter())),
(ArrowDataType::LargeUtf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i64>().iter())),
(ArrowDataType::Utf8View, PrimitiveType::String) => op_in(inlist, str_op(column.as_string_view().iter())),
(ArrowDataType::Int8, PrimitiveType::Byte) => op_in(inlist, op::<Int8Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Int16, PrimitiveType::Short) => op_in(inlist, op::<Int16Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Int32, PrimitiveType::Integer) => op_in(inlist, op::<Int32Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Int64, PrimitiveType::Long) => op_in(inlist, op::<Int64Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Float32, PrimitiveType::Float) => op_in(inlist, op::<Float32Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Float64, PrimitiveType::Double) => op_in(inlist, op::<Float64Type>(column.as_ref(), Scalar::from)),
(ArrowDataType::Date32, PrimitiveType::Date) => op_in(inlist, op::<Date32Type>(column.as_ref(), Scalar::Date)),
Collaborator commented:

These are all a lot longer than 100 chars... why doesn't the fmt check blow up??

(
ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(_)),
PrimitiveType::Timestamp,
) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::Timestamp)),
(
ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
PrimitiveType::TimestampNtz,
) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::TimestampNtz)),
(l, r) => {
return Err(Error::invalid_expression(format!(
"Cannot check if value of type '{l}' is contained in array with values of type '{r}'"
)))
}
};
Ok(Arc::new(arr))
}
(Literal(lit), Literal(Scalar::Array(ad))) => {
#[allow(deprecated)]
let exists = ad.array_elements().contains(lit);
@@ -382,8 +461,8 @@ fn new_field_with_metadata(

// A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct
// and use that method to transform each column and then put the struct back together. Target types
-// and names for each column should be passed in `target_types_and_names`. The number of elements in
-// the `target_types_and_names` iterator _must_ be the same as the number of columns in
+// and names for each column should be passed in `target_fields`. The number of elements in
+// the `target_fields` iterator _must_ be the same as the number of columns in
// `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields`
// _must_ match the order of the columns in `struct_array`.
fn transform_struct(
@@ -692,6 +771,85 @@ mod tests {
assert_eq!(in_result.as_ref(), &in_expected);
}

#[test]
fn test_column_in_array() {
Collaborator @scovich commented (Jan 30, 2025):

Relating to #652 (comment) and #652 (comment) -- we don't seem to have any tests that cover NULL value semantics?

1 IN (1, NULL) -- TRUE
1 IN (2, NULL) -- NULL
NULL IN (1, 2) -- NULL

1 NOT IN (1, NULL) -- FALSE (**)
1 NOT IN (2, NULL) -- NULL
NULL NOT IN (1, 2) -- NULL

(**) NOTE from https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-:

NOT IN always returns UNKNOWN when the list contains NULL, regardless of the input value. This is because IN returns UNKNOWN if the value is not in the list containing NULL, and because NOT UNKNOWN is again UNKNOWN.

IMO, that explanation is confusing and factually incorrect. If we explain it in terms of NOT(OR):

1 NOT IN (1, NULL)
= NOT(1 IN (1, NULL))
= NOT(1 = 1 OR 1 = NULL)
= NOT(1 = 1) AND NOT(1 = NULL)
= 1 != 1 AND 1 != NULL
= FALSE AND NULL
= FALSE

As additional support for my claim: sqlite, postgres, and mysql all return FALSE (not NULL) for that expression.
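The `NOT(OR)` expansion above can be checked mechanically with Kleene three-valued logic (a standalone sketch, not kernel code):

```rust
// Kleene three-valued NOT: NOT NULL = NULL.
fn kleene_not(v: Option<bool>) -> Option<bool> {
    v.map(|b| !b)
}

// Kleene three-valued AND: FALSE dominates, then NULL, then TRUE.
fn kleene_and(a: Option<bool>, b: Option<bool>) -> Option<bool> {
    match (a, b) {
        (Some(false), _) | (_, Some(false)) => Some(false),
        (None, _) | (_, None) => None,
        _ => Some(true),
    }
}

fn main() {
    // 1 NOT IN (1, NULL) = NOT(1=1) AND NOT(1=NULL) = FALSE AND NULL = FALSE
    assert_eq!(kleene_and(kleene_not(Some(true)), kleene_not(None)), Some(false));
    // 1 NOT IN (2, NULL) = NOT(1=2) AND NOT(1=NULL) = TRUE AND NULL = NULL
    assert_eq!(kleene_and(kleene_not(Some(false)), kleene_not(None)), None);
}
```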

Collaborator (Author) replied:

Added some tests for the cases mentioned above, hopefully covering everything. This uncovered some cases where we were not handling NULLs correctly in the other in-list branches, mainly because the arrow kernels don't seem to adhere to SQL NULL semantics.

In addition to the engines above, I also tried duckdb and datafusion, which also support @scovich's claim.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is this something worth upstreaming to arrow-rs, similar to the *_kleene variants of other kernels?

let values = Int32Array::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new("item", DataType::Int32, true));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Integer.into(), false),
[Scalar::Integer(1), Scalar::Integer(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
let in_expected = BooleanArray::from(vec![false, true, false, true]);
assert_eq!(in_result.as_ref(), &in_expected);

let not_in_op = Expression::binary(BinaryOperator::NotIn, column_expr!("item"), rhs);
let not_in_result =
evaluate_expression(&not_in_op, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
let not_in_expected = BooleanArray::from(vec![true, false, true, false]);
assert_eq!(not_in_result.as_ref(), &not_in_expected);

let in_expected = BooleanArray::from(vec![false, true, false, true]);

// Date arrays
let values = Date32Array::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new("item", DataType::Date32, true));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Date.into(), false),
[Scalar::Date(1), Scalar::Date(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);

// Timestamp arrays
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone("UTC");
let field = Arc::new(Field::new(
"item",
(&crate::schema::DataType::TIMESTAMP).try_into().unwrap(),
true,
));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::Timestamp.into(), false),
[Scalar::Timestamp(1), Scalar::Timestamp(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);

// Timestamp NTZ arrays
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]);
let field = Arc::new(Field::new(
"item",
(&crate::schema::DataType::TIMESTAMP_NTZ)
.try_into()
.unwrap(),
true,
));
let rhs = Expression::literal(Scalar::Array(ArrayData::new(
ArrayType::new(PrimitiveType::TimestampNtz.into(), false),
[Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)],
)));
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
let in_result =
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
assert_eq!(in_result.as_ref(), &in_expected);
}

#[test]
fn test_extract_column() {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
72 changes: 71 additions & 1 deletion kernel/src/expressions/scalars.rs
@@ -89,7 +89,7 @@ impl StructData {

/// A single value, which can be null. Used for representing literal values
/// in [Expressions][crate::expressions::Expression].
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone)]
pub enum Scalar {
/// 32bit integer
Integer(i32),
@@ -224,6 +224,48 @@ impl Display for Scalar {
}
}

impl PartialEq<Scalar> for Scalar {
fn eq(&self, other: &Self) -> bool {
use Scalar::*;
// NOTE: We intentionally do two match arms for each variant to avoid a catch-all, so
// that new variants trigger compilation failures instead of being silently ignored.
match (self, other) {
(Integer(a), Integer(b)) => a == b,
(Integer(_), _) => false,
(Long(a), Long(b)) => a == b,
(Long(_), _) => false,
(Short(a), Short(b)) => a == b,
(Short(_), _) => false,
(Byte(a), Byte(b)) => a == b,
(Byte(_), _) => false,
(Float(a), Float(b)) => a == b,
(Float(_), _) => false,
(Double(a), Double(b)) => a == b,
(Double(_), _) => false,
(String(a), String(b)) => a == b,
(String(_), _) => false,
(Boolean(a), Boolean(b)) => a == b,
(Boolean(_), _) => false,
(Timestamp(a), Timestamp(b)) => a == b,
(Timestamp(_), _) => false,
(TimestampNtz(a), TimestampNtz(b)) => a == b,
(TimestampNtz(_), _) => false,
(Date(a), Date(b)) => a == b,
(Date(_), _) => false,
(Binary(a), Binary(b)) => a == b,
(Binary(_), _) => false,
(Decimal(a, _, _), Decimal(b, _, _)) => a == b,
(Decimal(_, _, _), _) => false,
(Struct(a), Struct(b)) => a == b,
(Struct(_), _) => false,
(Array(a), Array(b)) => a == b,
(Array(_), _) => false,
(Null(_), Null(_)) => false, // NOTE: NULL values are incomparable by definition
(Null(_), _) => false,
}
}
}

impl PartialOrd for Scalar {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
use Scalar::*;
@@ -585,6 +627,7 @@ mod tests {
assert_eq!(&format!("{}", column_op), "3.1415927 IN Column(item)");
assert_eq!(&format!("{}", column_not_op), "'Cool' NOT IN Column(item)");
}

#[test]
fn test_timestamp_parse() {
let assert_timestamp_eq = |scalar_string, micros| {
@@ -599,6 +642,7 @@
assert_timestamp_eq("2011-01-11 13:06:07.123456", 1294751167123456);
assert_timestamp_eq("1970-01-01 00:00:00", 0);
}

#[test]
fn test_timestamp_ntz_parse() {
let assert_timestamp_eq = |scalar_string, micros| {
@@ -627,4 +671,30 @@
let p_type = PrimitiveType::Timestamp;
assert_timestamp_fails(&p_type, "1971-07-22");
}

#[test]
fn test_partial_cmp() {
let a = Scalar::Integer(1);
let b = Scalar::Integer(2);
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
assert_eq!(b.partial_cmp(&a), Some(Ordering::Greater));
assert_eq!(a.partial_cmp(&a), Some(Ordering::Equal));
assert_eq!(b.partial_cmp(&b), Some(Ordering::Equal));

// assert that NULL values are incomparable
let null = Scalar::Null(DataType::INTEGER);
assert_eq!(null.partial_cmp(&null), None);
}

#[test]
fn test_partial_eq() {
let a = Scalar::Integer(1);
let b = Scalar::Integer(2);
assert!(!a.eq(&b));
assert!(a.eq(&a));

// assert that NULL values are incomparable
let null = Scalar::Null(DataType::INTEGER);
assert!(!null.eq(&null));
}
}