-
Notifications
You must be signed in to change notification settings - Fork 60
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support 'col IN (a, b, c)' type expressions #652
base: main
Are you sure you want to change the base?
Changes from 1 commit
28a5648
d6e3730
848ef11
290d65d
def21c1
6b959eb
4e3e92c
ba3b4e9
233c0e8
6c813e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,6 +280,81 @@ fn evaluate_expression( | |
(ArrowDataType::Decimal256(_, _), Decimal256Type) | ||
} | ||
} | ||
(Column(name), Literal(Scalar::Array(ad))) => { | ||
use crate::expressions::ArrayData; | ||
|
||
let column = extract_column(batch, name)?; | ||
let data_type = ad | ||
.array_type() | ||
.element_type() | ||
.as_primitive_opt() | ||
.ok_or_else(|| { | ||
Error::invalid_expression(format!( | ||
"IN only supports array literals with primitive elements, got: '{:?}'", | ||
ad.array_type().element_type() | ||
)) | ||
})?; | ||
|
||
fn op( | ||
col: impl Iterator<Item = Option<impl Into<Scalar>>>, | ||
roeap marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ad: &ArrayData, | ||
) -> BooleanArray { | ||
#[allow(deprecated)] | ||
let res = col.map(|val| val.map(|v| ad.array_elements().contains(&v.into()))); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't think this handles NULL values correctly? See e.g. https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-:
I think, instead of calling […]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, I think you could just invoke that method directly, with a properly crafted iterator? // `v IN (k1, ..., kN)` is logically equivalent to `v = k1 OR ... OR v = kN`, so evaluate
// it as such, ensuring correct handling of NULL inputs (including `Scalar::Null`).
col.map(|v| {
PredicateEvaluatorDefaults::finish_eval_variadic(
VariadicOperator::Or,
inlist.iter().map(|k| Some(Scalar::partial_cmp(v?, k?)? == Ordering::Equal)),
false,
)
}) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was I correct in thinking that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aside: We actually have a lurking bug -- Our manual impl of
Looks like we'll need to define a manual There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is indeed not covered. Added an implementation for |
||
BooleanArray::from_iter(res) | ||
} | ||
roeap marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// safety: as_* methods on arrow arrays can panic, but we checked the data type before applying. | ||
let arr: BooleanArray = match (column.data_type(), data_type) { | ||
roeap marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(ArrowDataType::Utf8, PrimitiveType::String) => op(column.as_string::<i32>().iter(), ad), | ||
(ArrowDataType::LargeUtf8, PrimitiveType::String) => op(column.as_string::<i64>().iter(), ad), | ||
(ArrowDataType::Utf8View, PrimitiveType::String) => op(column.as_string_view().iter(), ad), | ||
(ArrowDataType::Int8, PrimitiveType::Byte) => op(column.as_primitive::<Int8Type>().iter(), ad), | ||
(ArrowDataType::Int16, PrimitiveType::Short) => op(column.as_primitive::<Int16Type>().iter(), ad), | ||
(ArrowDataType::Int32, PrimitiveType::Integer) => op(column.as_primitive::<Int32Type>().iter(), ad), | ||
(ArrowDataType::Int64, PrimitiveType::Long) => op(column.as_primitive::<Int64Type>().iter(), ad), | ||
(ArrowDataType::Float32, PrimitiveType::Float) => op(column.as_primitive::<Float32Type>().iter(), ad), | ||
(ArrowDataType::Float64, PrimitiveType::Double) => op(column.as_primitive::<Float64Type>().iter(), ad), | ||
(ArrowDataType::Date32, PrimitiveType::Date) => { | ||
#[allow(deprecated)] | ||
let res = column | ||
.as_primitive::<Date32Type>() | ||
.iter() | ||
.map(|val| val.map(|v| ad.array_elements().contains(&Scalar::Date(v)))); | ||
BooleanArray::from_iter(res) | ||
} | ||
( | ||
ArrowDataType::Timestamp(TimeUnit::Microsecond, unit), | ||
kt @ PrimitiveType::Timestamp | kt @ PrimitiveType::TimestampNtz, | ||
) => { | ||
let res = column.as_primitive::<TimestampMicrosecondType>().iter(); | ||
match (unit, kt) { | ||
// regardless of the time zone stored in the timestamp, the underlying value is always in UTC | ||
(Some(_), PrimitiveType::Timestamp) => { | ||
BooleanArray::from_iter(res.map(|val| { | ||
#[allow(deprecated)] | ||
val.map(|v| ad.array_elements().contains(&Scalar::Timestamp(v))) | ||
})) | ||
} | ||
(None, PrimitiveType::TimestampNtz) => { | ||
BooleanArray::from_iter(res.map(|val| { | ||
val.map(|v| { | ||
#[allow(deprecated)] | ||
ad.array_elements().contains(&Scalar::TimestampNtz(v)) | ||
}) | ||
})) | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
scovich marked this conversation as resolved.
Show resolved
Hide resolved
|
||
(l, r) => { | ||
return Err(Error::invalid_expression(format!( | ||
"Cannot check if value of type '{l}' is contained in array with values of type '{r}'" | ||
))) | ||
} | ||
}; | ||
Ok(Arc::new(arr)) | ||
} | ||
(Literal(lit), Literal(Scalar::Array(ad))) => { | ||
#[allow(deprecated)] | ||
let exists = ad.array_elements().contains(lit); | ||
|
@@ -382,8 +457,8 @@ fn new_field_with_metadata( | |
|
||
// A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct | ||
// and use that method to transform each column and then put the struct back together. Target types | ||
// and names for each column should be passed in `target_types_and_names`. The number of elements in | ||
// the `target_types_and_names` iterator _must_ be the same as the number of columns in | ||
// and names for each column should be passed in `target_fields`. The number of elements in | ||
// the `target_fields` iterator _must_ be the same as the number of columns in | ||
// `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields` | ||
// _must_ match the order of the columns in `struct_array`. | ||
fn transform_struct( | ||
|
@@ -692,6 +767,85 @@ mod tests { | |
assert_eq!(in_result.as_ref(), &in_expected); | ||
} | ||
|
||
#[test] | ||
roeap marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fn test_column_in_array() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Relating to #652 (comment) and #652 (comment) -- we don't seem to have any tests that cover NULL value semantics?
1 IN (1, NULL) -- TRUE
1 IN (2, NULL) -- NULL
NULL IN (1, 2) -- NULL
1 NOT IN (1, NULL) -- FALSE (**)
1 NOT IN (2, NULL) -- NULL
NULL NOT IN (1, 2) -- NULL (**) NOTE from https://spark.apache.org/docs/3.5.1/sql-ref-null-semantics.html#innot-in-subquery-:
IMO, that explanation is confusing and factually incorrect. If we explain it in terms of NOT(OR): 1 NOT IN (1, NULL)
= NOT(1 IN (1, NULL))
= NOT(1 = 1 OR 1 = NULL)
= NOT(1 = 1) AND NOT(1 = NULL)
= 1 != 1 AND 1 != NULL
= FALSE AND NULL
= FALSE As additional support for my claim: sqlite, postgres, and mysql all return FALSE (not NULL) for that expression. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added some test according to the cases mentioned above, hopefully covering all cases. This uncovered some cases where we were not handling NULLs correctly in the other in-list branches, mainly b/c the arrow kernels don't seem to be adhering to the SQL NULL semantics. In addition to the engines above, I also tried duckdb and datafusion, which also support @scovich's claim. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also, is this something worth upstreaming to |
||
let values = Int32Array::from(vec![0, 1, 2, 3]); | ||
let field = Arc::new(Field::new("item", DataType::Int32, true)); | ||
let rhs = Expression::literal(Scalar::Array(ArrayData::new( | ||
ArrayType::new(PrimitiveType::Integer.into(), false), | ||
[Scalar::Integer(1), Scalar::Integer(3)], | ||
))); | ||
let schema = Schema::new([field.clone()]); | ||
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); | ||
|
||
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); | ||
let in_result = | ||
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); | ||
let in_expected = BooleanArray::from(vec![false, true, false, true]); | ||
assert_eq!(in_result.as_ref(), &in_expected); | ||
|
||
let not_in_op = Expression::binary(BinaryOperator::NotIn, column_expr!("item"), rhs); | ||
let not_in_result = | ||
evaluate_expression(¬_in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)) | ||
.unwrap(); | ||
let not_in_expected = BooleanArray::from(vec![true, false, true, false]); | ||
assert_eq!(not_in_result.as_ref(), ¬_in_expected); | ||
|
||
let in_expected = BooleanArray::from(vec![false, true, false, true]); | ||
|
||
// Date arrays | ||
let values = Date32Array::from(vec![0, 1, 2, 3]); | ||
let field = Arc::new(Field::new("item", DataType::Date32, true)); | ||
let rhs = Expression::literal(Scalar::Array(ArrayData::new( | ||
ArrayType::new(PrimitiveType::Date.into(), false), | ||
[Scalar::Date(1), Scalar::Date(3)], | ||
))); | ||
let schema = Schema::new([field.clone()]); | ||
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); | ||
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); | ||
let in_result = | ||
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); | ||
assert_eq!(in_result.as_ref(), &in_expected); | ||
|
||
// Timestamp arrays | ||
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone("UTC"); | ||
let field = Arc::new(Field::new( | ||
"item", | ||
(&crate::schema::DataType::TIMESTAMP).try_into().unwrap(), | ||
true, | ||
)); | ||
let rhs = Expression::literal(Scalar::Array(ArrayData::new( | ||
ArrayType::new(PrimitiveType::Timestamp.into(), false), | ||
[Scalar::Timestamp(1), Scalar::Timestamp(3)], | ||
))); | ||
let schema = Schema::new([field.clone()]); | ||
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); | ||
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); | ||
let in_result = | ||
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); | ||
assert_eq!(in_result.as_ref(), &in_expected); | ||
|
||
// Timestamp NTZ arrays | ||
let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]); | ||
let field = Arc::new(Field::new( | ||
"item", | ||
(&crate::schema::DataType::TIMESTAMP_NTZ) | ||
.try_into() | ||
.unwrap(), | ||
true, | ||
)); | ||
let rhs = Expression::literal(Scalar::Array(ArrayData::new( | ||
ArrayType::new(PrimitiveType::TimestampNtz.into(), false), | ||
[Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)], | ||
))); | ||
let schema = Schema::new([field.clone()]); | ||
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); | ||
let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); | ||
let in_result = | ||
evaluate_expression(&in_op, &batch, Some(&crate::schema::DataType::BOOLEAN)).unwrap(); | ||
assert_eq!(in_result.as_ref(), &in_expected); | ||
} | ||
|
||
#[test] | ||
fn test_extract_column() { | ||
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The data in arrow arrays should always represent a timestamp in UTC, so is this check even necessary?
https://github.com/apache/arrow-rs/blob/af777cd53e56f8382382137b6e08af249c475397/arrow-schema/src/datatype.rs#L179-L182