feat: support 'col IN (a, b, c)' type expressions #652
base: main
Changes from 5 commits
Commits: 28a5648, d6e3730, 848ef11, 290d65d, def21c1, 6b959eb, 4e3e92c, ba3b4e9, 233c0e8, 6c813e9
```diff
@@ -32,6 +32,7 @@ use crate::expressions::{
     BinaryExpression, BinaryOperator, Expression, Scalar, UnaryExpression, UnaryOperator,
     VariadicExpression, VariadicOperator,
 };
+use crate::predicates::PredicateEvaluatorDefaults;
 use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
 use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};
```

```rust
@@ -280,6 +281,84 @@ fn evaluate_expression(
                (ArrowDataType::Decimal256(_, _), Decimal256Type)
            }
        }
        (Column(name), Literal(Scalar::Array(ad))) => {
            fn op<T: ArrowPrimitiveType>(
                values: &dyn Array,
```
[Review thread on `values: &dyn Array` -- resolved]

Reviewer: Suggested change: … (avoids the need for …)

roeap: i think the main thing was that we need this to be a reference, otherwise the compiler starts complaining about lifetimes. I did shorten the code at the call-site a bit, hope that works as well.

Reviewer: Why would an …

roeap: not the parameter itself, but the …
```rust
                from: fn(T::Native) -> Scalar,
            ) -> impl Iterator<Item = Option<Scalar>> + '_ {
                values.as_primitive::<T>().iter().map(move |v| v.map(from))
            }

            fn str_op<'a>(
                column: impl Iterator<Item = Option<&'a str>> + 'a,
            ) -> impl Iterator<Item = Option<Scalar>> + 'a {
                column.map(|v| v.map(Scalar::from))
            }

            fn op_in(
                inlist: &[Scalar],
                values: impl Iterator<Item = Option<Scalar>>,
            ) -> BooleanArray {
                // `v IN (k1, ..., kN)` is logically equivalent to `v = k1 OR ... OR v = kN`,
                // so evaluate it as such, ensuring correct handling of NULL inputs
                // (including `Scalar::Null`).
                values
                    .map(|v| {
                        Some(
                            PredicateEvaluatorDefaults::finish_eval_variadic(
                                VariadicOperator::Or,
                                inlist.iter().map(|k| v.as_ref().map(|vv| vv == k)),
```
[Review thread on the comparison inside `op_in`]

Reviewer: This isn't correct -- we need comparisons against … Also, can we not use …

Suggested change: …

Unpacking that -- if the value we search for is NULL, or if the inlist entry is NULL, or if the two values are incomparable, then return None for that pair. Otherwise, return Some boolean indicating whether the values compared equal or not. That automatically covers the various required cases, and also makes us robust to any type mismatches that might creep in.

Reviewer: Note: If we wanted to be a tad more efficient, we could also unpack …

```rust
values.map(|v| {
    let v = v?;
    PredicateEvaluatorDefaults::finish_eval_variadic(...)
})
```

Reviewer: Hmm -- empty in-lists pose a corner case with respect to unpacking: `NULL IN ()`. Operator OR with zero inputs normally produces FALSE (which is correct if you stop to think about it) -- but unpacking a NULL … Unfortunately, the only clear docs I could find -- https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery- -- are also ambiguous: one passage suggests FALSE, while another suggests NULL. The difference matters for NOT IN, because …

Reviewer: NOTE: SQL engines normally forbid statically empty in-lists but do not forbid subqueries from producing an empty result. I tried the following expression on three engines (sqlite, mysql, postgres):

```sql
SELECT 1 WHERE NULL NOT IN (SELECT 1 WHERE FALSE)
```

And all three returned …

roeap: Hoping I now considered all your comments, which essentially means going with your original version.
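The zero-input-OR behavior the reviewer describes can be modeled in a few lines. The sketch below is a hypothetical stand-in for the three-valued OR semantics attributed to `finish_eval_variadic(VariadicOperator::Or, ...)` in this thread, not the kernel's actual code; `kleene_or` is an invented helper name.

```rust
// Hypothetical model of three-valued (Kleene) OR over nullable booleans:
// Some(true) dominates; otherwise NULL (None) poisons the result; an
// empty input list yields Some(false).
fn kleene_or(inputs: impl IntoIterator<Item = Option<bool>>) -> Option<bool> {
    let mut saw_null = false;
    for i in inputs {
        match i {
            Some(true) => return Some(true), // dominant value short-circuits
            Some(false) => {}
            None => saw_null = true,
        }
    }
    if saw_null { None } else { Some(false) }
}

fn main() {
    // `v IN (k1, ..., kN)` folds the `v = k` comparisons with this OR.
    // An empty in-list gives FALSE even when v is NULL...
    let null_in_empty = kleene_or(Vec::<Option<bool>>::new());
    assert_eq!(null_in_empty, Some(false));
    // ...so `NULL NOT IN ()` is TRUE, matching the engines tried above.
    assert_eq!(null_in_empty.map(|b| !b), Some(true));
    // A NULL comparison with no TRUE present yields NULL.
    assert_eq!(kleene_or([Some(false), None]), None);
}
```

Under this model, `.unwrap_or(false)` at the end of `op_in` is what maps a NULL overall result to FALSE for `IN`.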
```rust
                                false,
                            )
                            // None is returned when no dominant value (true) is found
                            // and there is at least one NULL. In the case of IN, this
                            // is equivalent to false.
                            .unwrap_or(false),
                        )
                    })
                    .collect()
            }
```
```rust
            #[allow(deprecated)]
            let inlist = ad.array_elements();
            let column = extract_column(batch, name)?;
            let data_type = ad
                .array_type()
                .element_type()
                .as_primitive_opt()
                .ok_or_else(|| {
                    Error::invalid_expression(format!(
                        "IN only supports array literals with primitive elements, got: '{:?}'",
                        ad.array_type().element_type()
                    ))
                })?;

            // safety: as_* methods on arrow arrays can panic, but we checked the data type before applying.
            let arr = match (column.data_type(), data_type) {
                (ArrowDataType::Utf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i32>().iter())),
                (ArrowDataType::LargeUtf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i64>().iter())),
                (ArrowDataType::Utf8View, PrimitiveType::String) => op_in(inlist, str_op(column.as_string_view().iter())),
                (ArrowDataType::Int8, PrimitiveType::Byte) => op_in(inlist, op::<Int8Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Int16, PrimitiveType::Short) => op_in(inlist, op::<Int16Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Int32, PrimitiveType::Integer) => op_in(inlist, op::<Int32Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Int64, PrimitiveType::Long) => op_in(inlist, op::<Int64Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Float32, PrimitiveType::Float) => op_in(inlist, op::<Float32Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Float64, PrimitiveType::Double) => op_in(inlist, op::<Float64Type>(column.as_ref(), Scalar::from)),
                (ArrowDataType::Date32, PrimitiveType::Date) => op_in(inlist, op::<Date32Type>(column.as_ref(), Scalar::Date)),
```
[Review comment on the match arms above]

Reviewer: These are all a lot longer than 100 chars... why doesn't the fmt check blow up??
```rust
                (
                    ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(_)),
                    PrimitiveType::Timestamp,
                ) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::Timestamp)),
                (
                    ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
                    PrimitiveType::TimestampNtz,
                ) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::TimestampNtz)),
                (l, r) => {
                    return Err(Error::invalid_expression(format!(
                        "Cannot check if value of type '{l}' is contained in array with values of type '{r}'"
                    )))
                }
            };
            Ok(Arc::new(arr))
        }
        (Literal(lit), Literal(Scalar::Array(ad))) => {
            #[allow(deprecated)]
            let exists = ad.array_elements().contains(lit);
```
```diff
@@ -382,8 +461,8 @@ fn new_field_with_metadata(

 // A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct
 // and use that method to transform each column and then put the struct back together. Target types
-// and names for each column should be passed in `target_types_and_names`. The number of elements in
-// the `target_types_and_names` iterator _must_ be the same as the number of columns in
+// and names for each column should be passed in `target_fields`. The number of elements in
+// the `target_fields` iterator _must_ be the same as the number of columns in
 // `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields`
 // _must_ match the order of the columns in `struct_array`.
 fn transform_struct(
```
```rust
@@ -692,6 +771,85 @@ mod tests {
        assert_eq!(in_result.as_ref(), &in_expected);
    }

    #[test]
    fn test_column_in_array() {
```
[Review thread on `test_column_in_array`]

Reviewer: Relating to #652 (comment) and #652 (comment) -- we don't seem to have any tests that cover NULL value semantics?

```sql
1 IN (1, NULL)      -- TRUE
1 IN (2, NULL)      -- NULL
NULL IN (1, 2)      -- NULL
1 NOT IN (1, NULL)  -- FALSE (**)
1 NOT IN (2, NULL)  -- NULL
NULL NOT IN (1, 2)  -- NULL (**)
```

(**) NOTE from https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-: … IMO, that explanation is confusing and factually incorrect. If we explain it in terms of NOT(OR):

```sql
1 NOT IN (1, NULL)
= NOT(1 IN (1, NULL))
= NOT(1 = 1 OR 1 = NULL)
= NOT(1 = 1) AND NOT(1 = NULL)
= 1 != 1 AND 1 != NULL
= FALSE AND NULL
= FALSE
```

As additional support for my claim: sqlite, postgres, and mysql all return FALSE (not NULL) for that expression.

roeap: Added some tests according to the cases mentioned above, hopefully covering all cases. This uncovered some cases where we were not handling NULLs correctly in the other in-list branches, mainly b/c the arrow kernels don't seem to be adhering to the SQL NULL semantics. In addition to the engines above, I also tried duckdb and datafusion, which also support @scovich's claim.

roeap: also, is this something worth upstreaming to …
```rust
        let values = Int32Array::from(vec![0, 1, 2, 3]);
        let field = Arc::new(Field::new("item", DataType::Int32, true));
        let rhs = Expression::literal(Scalar::Array(ArrayData::new(
            ArrayType::new(PrimitiveType::Integer.into(), false),
            [Scalar::Integer(1), Scalar::Integer(3)],
        )));
        let schema = Schema::new([field.clone()]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

        let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
        let in_result =
            evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
        let in_expected = BooleanArray::from(vec![false, true, false, true]);
        assert_eq!(in_result.as_ref(), &in_expected);

        let not_in_op = Expression::binary(BinaryOperator::NotIn, column_expr!("item"), rhs);
        let not_in_result =
            evaluate_expression(&not_in_op, &batch, Some(&crate::schema::DataType::BOOLEAN))
                .unwrap();
        let not_in_expected = BooleanArray::from(vec![true, false, true, false]);
        assert_eq!(not_in_result.as_ref(), &not_in_expected);

        let in_expected = BooleanArray::from(vec![false, true, false, true]);

        // Date arrays
        let values = Date32Array::from(vec![0, 1, 2, 3]);
        let field = Arc::new(Field::new("item", DataType::Date32, true));
        let rhs = Expression::literal(Scalar::Array(ArrayData::new(
            ArrayType::new(PrimitiveType::Date.into(), false),
            [Scalar::Date(1), Scalar::Date(3)],
        )));
        let schema = Schema::new([field.clone()]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
        let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
        let in_result =
            evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
        assert_eq!(in_result.as_ref(), &in_expected);

        // Timestamp arrays
        let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone("UTC");
        let field = Arc::new(Field::new(
            "item",
            (&crate::schema::DataType::TIMESTAMP).try_into().unwrap(),
            true,
        ));
        let rhs = Expression::literal(Scalar::Array(ArrayData::new(
            ArrayType::new(PrimitiveType::Timestamp.into(), false),
            [Scalar::Timestamp(1), Scalar::Timestamp(3)],
        )));
        let schema = Schema::new([field.clone()]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
        let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
        let in_result =
            evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
        assert_eq!(in_result.as_ref(), &in_expected);

        // Timestamp NTZ arrays
        let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]);
        let field = Arc::new(Field::new(
            "item",
            (&crate::schema::DataType::TIMESTAMP_NTZ)
                .try_into()
                .unwrap(),
            true,
        ));
        let rhs = Expression::literal(Scalar::Array(ArrayData::new(
            ArrayType::new(PrimitiveType::TimestampNtz.into(), false),
            [Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)],
        )));
        let schema = Schema::new([field.clone()]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();
        let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone());
        let in_result =
            evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap();
        assert_eq!(in_result.as_ref(), &in_expected);
    }

    #[test]
    fn test_extract_column() {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
```
[Review comment on the timestamp timezone matching]

Reviewer: The data in arrow arrays should always represent a timestamp in UTC, so is this check even necessary? https://github.com/apache/arrow-rs/blob/af777cd53e56f8382382137b6e08af249c475397/arrow-schema/src/datatype.rs#L179-L182