diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index a045153cf..061cf0a6a 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -9,7 +9,7 @@ repository.workspace = true readme.workspace = true version.workspace = true # exclude golden tests + golden test data since they push us over 10MB crate size limit -exclude = ["tests/golden_tables.rs", "tests/golden_data/" ] +exclude = ["tests/golden_tables.rs", "tests/golden_data/"] rust-version.workspace = true [package.metadata.docs.rs] @@ -17,10 +17,22 @@ all-features = true [package.metadata.release] pre-release-replacements = [ - {file="../README.md", search="delta_kernel = \"[a-z0-9\\.-]+\"", replace="delta_kernel = \"{{version}}\""}, - {file="../README.md", search="version = \"[a-z0-9\\.-]+\"", replace="version = \"{{version}}\""}, + { file = "../README.md", search = "delta_kernel = \"[a-z0-9\\.-]+\"", replace = "delta_kernel = \"{{version}}\"" }, + { file = "../README.md", search = "version = \"[a-z0-9\\.-]+\"", replace = "version = \"{{version}}\"" }, +] +pre-release-hook = [ + "git", + "cliff", + "--repository", + "../", + "--config", + "../cliff.toml", + "--unreleased", + "--prepend", + "../CHANGELOG.md", + "--tag", + "{{version}}", ] -pre-release-hook = ["git", "cliff", "--repository", "../", "--config", "../cliff.toml", "--unreleased", "--prepend", "../CHANGELOG.md", "--tag", "{{version}}" ] [dependencies] bytes = "1.7" @@ -71,10 +83,17 @@ tokio = { version = "1.40", optional = true, features = ["rt-multi-thread"] } # Used in integration tests hdfs-native = { workspace = true, optional = true } walkdir = { workspace = true, optional = true } +rust_decimal = "1.36.0" [features] arrow-conversion = ["arrow-schema"] -arrow-expression = ["arrow-arith", "arrow-array", "arrow-buffer", "arrow-ord", "arrow-schema"] +arrow-expression = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", +] cloud = [ "object_store/aws", "object_store/azure", @@ -106,10 +125,7 @@ 
default-engine-base = [ # the default-engine use the reqwest crate with default features which uses native-tls. if you want # to instead use rustls, use 'default-engine-rustls' which has no native-tls dependency -default-engine = [ - "default-engine-base", - "reqwest/default", -] +default-engine = ["default-engine-base", "reqwest/default"] default-engine-rustls = [ "default-engine-base", diff --git a/kernel/src/engine/arrow_conversion.rs b/kernel/src/engine/arrow_conversion.rs index 0b905ff3a..103a6279b 100644 --- a/kernel/src/engine/arrow_conversion.rs +++ b/kernel/src/engine/arrow_conversion.rs @@ -208,7 +208,7 @@ impl TryFrom<&ArrowDataType> for DataType { ArrowDataType::Date64 => Ok(DataType::DATE), ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ), ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz)) - if tz.eq_ignore_ascii_case("utc") => + if tz.eq_ignore_ascii_case("utc") || tz.eq_ignore_ascii_case("+00:00") => { Ok(DataType::TIMESTAMP) } diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs index 8ee54ebd0..0ed165585 100644 --- a/kernel/src/engine/arrow_expression.rs +++ b/kernel/src/engine/arrow_expression.rs @@ -1,5 +1,6 @@ //! Expression handling based on arrow-rs compute kernels. 
 use std::borrow::Borrow;
+use std::cmp::Ordering;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -32,6 +33,7 @@ use crate::expressions::{
     BinaryExpression, BinaryOperator, Expression, Scalar, UnaryExpression, UnaryOperator,
     VariadicExpression, VariadicOperator,
 };
+use crate::predicates::PredicateEvaluatorDefaults;
 use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
 use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};
 
@@ -236,14 +238,26 @@ fn evaluate_expression(
         }),
         _,
     ) => match (left.as_ref(), right.as_ref()) {
-        (Literal(_), Column(_)) => {
+        (Literal(lit), Column(_)) => {
+            if lit.is_null() {
+                return Ok(Arc::new(BooleanArray::from(vec![None; batch.num_rows()])));
+            }
             let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
             let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
             if let Some(string_arr) = left_arr.as_string_opt::<i32>() {
                 if let Some(right_arr) = right_arr.as_list_opt::<i32>() {
-                    return in_list_utf8(string_arr, right_arr)
-                        .map(wrap_comparison_result)
-                        .map_err(Error::generic_err);
+                    let in_list_result =
+                        in_list_utf8(string_arr, right_arr).map_err(Error::generic_err)?;
+                    return Ok(wrap_comparison_result(
+                        in_list_result
+                            .iter()
+                            .zip(right_arr.iter())
+                            .map(|(res, arr)| match (res, arr) {
+                                (Some(false), Some(arr)) if arr.null_count() > 0 => None,
+                                _ => res,
+                            })
+                            .collect(),
+                    ));
                 }
             }
             prim_array_cmp!
            {
@@ -280,10 +294,93 @@ fn evaluate_expression(
                 (ArrowDataType::Decimal256(_, _), Decimal256Type)
             }
         }
+        (Column(name), Literal(Scalar::Array(ad))) => {
+            fn op<T: ArrowPrimitiveType>(
+                values: &dyn Array,
+                from: fn(T::Native) -> Scalar,
+            ) -> impl Iterator<Item = Option<Scalar>> + '_ {
+                values.as_primitive::<T>().iter().map(move |v| v.map(from))
+            }
+
+            fn str_op<'a>(
+                column: impl IntoIterator<Item = Option<&'a str>> + 'a,
+            ) -> impl Iterator<Item = Option<Scalar>> + 'a {
+                column.into_iter().map(|v| v.map(Scalar::from))
+            }
+
+            fn op_in(
+                inlist: &[Scalar],
+                values: impl Iterator<Item = Option<Scalar>>,
+            ) -> BooleanArray {
+                // `v IN (k1, ..., kN)` is logically equivalent to `v = k1 OR ... OR v = kN`, so evaluate
+                // it as such, ensuring correct handling of NULL inputs (including `Scalar::Null`).
+                values
+                    .map(|v| {
+                        PredicateEvaluatorDefaults::finish_eval_variadic(
+                            VariadicOperator::Or,
+                            inlist
+                                .iter()
+                                .map(|k| Some(v.as_ref()?.partial_cmp(k)? == Ordering::Equal)),
+                            false,
+                        )
+                    })
+                    .collect()
+            }
+
+            #[allow(deprecated)]
+            let inlist = ad.array_elements();
+            let column = extract_column(batch, name)?;
+            let data_type = ad
+                .array_type()
+                .element_type()
+                .as_primitive_opt()
+                .ok_or_else(|| {
+                    Error::invalid_expression(format!(
+                        "IN only supports array literals with primitive elements, got: '{:?}'",
+                        ad.array_type().element_type()
+                    ))
+                })?;
+
+            // safety: as_* methods on arrow arrays can panic, but we checked the data type before applying.
+            let arr = match (column.data_type(), data_type) {
+                (ArrowDataType::Utf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i32>())),
+                (ArrowDataType::LargeUtf8, PrimitiveType::String) => op_in(inlist, str_op(column.as_string::<i64>())),
+                (ArrowDataType::Utf8View, PrimitiveType::String) => op_in(inlist, str_op(column.as_string_view())),
+                (ArrowDataType::Int8, PrimitiveType::Byte) => op_in(inlist, op::<Int8Type>(&column, Scalar::from)),
+                (ArrowDataType::Int16, PrimitiveType::Short) => op_in(inlist, op::<Int16Type>(&column, Scalar::from)),
+                (ArrowDataType::Int32, PrimitiveType::Integer) => op_in(inlist, op::<Int32Type>(&column, Scalar::from)),
+                (ArrowDataType::Int64, PrimitiveType::Long) => op_in(inlist, op::<Int64Type>(&column, Scalar::from)),
+                (ArrowDataType::Float32, PrimitiveType::Float) => op_in(inlist, op::<Float32Type>(&column, Scalar::from)),
+                (ArrowDataType::Float64, PrimitiveType::Double) => op_in(inlist, op::<Float64Type>(&column, Scalar::from)),
+                (ArrowDataType::Date32, PrimitiveType::Date) => {
+                    op_in(inlist, op::<Date32Type>(&column, Scalar::Date))
+                },
+                (
+                    ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(_)),
+                    PrimitiveType::Timestamp,
+                ) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::Timestamp)),
+                (
+                    ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
+                    PrimitiveType::TimestampNtz,
+                ) => op_in(inlist, op::<TimestampMicrosecondType>(column.as_ref(), Scalar::TimestampNtz)),
+                (l, r) => {
+                    return Err(Error::invalid_expression(format!(
+                        "Cannot check if value of type '{l}' is contained in array with values of type '{r}'"
+                    )))
+                }
+            };
+            Ok(Arc::new(arr))
+        }
         (Literal(lit), Literal(Scalar::Array(ad))) => {
             #[allow(deprecated)]
-            let exists = ad.array_elements().contains(lit);
-            Ok(Arc::new(BooleanArray::from(vec![exists])))
+            let exists = PredicateEvaluatorDefaults::finish_eval_variadic(
+                VariadicOperator::Or,
+                ad.array_elements()
+                    .iter()
+                    .map(|k| Some(lit.partial_cmp(k)? == Ordering::Equal)),
+                false,
+            );
+            Ok(Arc::new(BooleanArray::from(vec![exists; batch.num_rows()])))
         }
         (l, r) => Err(Error::invalid_expression(format!(
             "Invalid right value for (NOT) IN comparison, left is: {l} right is: {r}"
@@ -382,8 +479,8 @@ fn new_field_with_metadata(
 // A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct
 // and use that method to transform each column and then put the struct back together. Target types
-// and names for each column should be passed in `target_types_and_names`. The number of elements in
-// the `target_types_and_names` iterator _must_ be the same as the number of columns in
+// and names for each column should be passed in `target_fields`. The number of elements in
+// the `target_fields` iterator _must_ be the same as the number of columns in
 // `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields`
 // _must_ match the order of the columns in `struct_array`.
 fn transform_struct(
@@ -590,13 +687,11 @@ mod tests {
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();
 
         let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));
-
-        let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item"));
-
         let result = evaluate_expression(&not_op, &batch, None).unwrap();
         let expected = BooleanArray::from(vec![true, false, true]);
         assert_eq!(result.as_ref(), &expected);
 
+        let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item"));
         let in_result = evaluate_expression(&in_op, &batch, None).unwrap();
         let in_expected = BooleanArray::from(vec![false, true, false]);
         assert_eq!(in_result.as_ref(), &in_expected);
@@ -621,7 +716,7 @@ mod tests {
     }
 
     #[test]
-    fn test_literal_type_array() {
+    fn test_literal_type_array_empty() {
         let field = Arc::new(Field::new("item", DataType::Int32, true));
         let schema = Schema::new([field.clone()]);
         let batch = RecordBatch::new_empty(Arc::new(schema));
@@ -636,7 +731,7 @@ mod tests {
         );
 
         let in_result = evaluate_expression(&in_op, &batch, None).unwrap();
-        let in_expected = BooleanArray::from(vec![true]);
+        let in_expected = BooleanArray::from(Vec::<Option<bool>>::new());
         assert_eq!(in_result.as_ref(), &in_expected);
     }
 
@@ -692,6 +787,250 @@ mod tests {
         assert_eq!(in_result.as_ref(), &in_expected);
     }
 
+    #[test]
+    fn test_str_arrays_with_null() {
+        let values = GenericStringArray::<i32>::from(vec![
+            Some("one"),
+            None,
+            Some("two"),
+            None,
+            Some("one"),
+            Some("two"),
+        ]);
+        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 6]));
+        let field = Arc::new(Field::new("item", DataType::Utf8, true));
+        let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true));
+        let schema = Schema::new([arr_field.clone()]);
+        let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();
+
+        let in_op = Expression::binary(BinaryOperator::In, "one", column_expr!("item"));
+        let in_result =
+            evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap();
+        let in_expected = BooleanArray::from(vec![Some(true), None, Some(true)]);
+        assert_eq!(in_result.as_ref(), &in_expected);
+
+        let in_op = Expression::binary(BinaryOperator::In, "two", column_expr!("item"));
+        let in_result =
+            evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap();
+        let in_expected = BooleanArray::from(vec![None, Some(true), Some(true)]);
+        assert_eq!(in_result.as_ref(), &in_expected);
+
+        let in_op = Expression::binary(
+            BinaryOperator::In,
+            Scalar::Null(DeltaDataTypes::STRING),
+            column_expr!("item"),
+        );
+        let in_result =
+            evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap();
+        let in_expected = BooleanArray::from(vec![None, None, None]);
+        assert_eq!(in_result.as_ref(), &in_expected);
+    }
+
+    #[test]
+    fn test_arrays_with_null() {
+ let values = Int32Array::from(vec![Some(1), None, Some(2), None, Some(1), Some(2)]); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 6])); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let arr_field = Arc::new(Field::new("item", DataType::List(field.clone()), true)); + let schema = Schema::new([arr_field.clone()]); + let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); + + let in_op = Expression::binary(BinaryOperator::In, 1, column_expr!("item")); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![Some(true), None, Some(true)]); + assert_eq!(in_result.as_ref(), &in_expected); + + let in_op = Expression::binary(BinaryOperator::In, 2, column_expr!("item")); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![None, Some(true), Some(true)]); + assert_eq!(in_result.as_ref(), &in_expected); + + let in_op = Expression::binary( + BinaryOperator::In, + Scalar::Null(DeltaDataTypes::INTEGER), + column_expr!("item"), + ); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![None, None, None]); + assert_eq!(in_result.as_ref(), &in_expected); + } + + #[test] + fn test_column_in_array() { + let values = Int32Array::from(vec![0, 1, 2, 3]); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Integer.into(), false), + [Scalar::Integer(1), Scalar::Integer(3)], + ))); + let schema = Schema::new([field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); + + let in_op = 
Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![false, true, false, true]); + assert_eq!(in_result.as_ref(), &in_expected); + + let not_in_op = Expression::binary(BinaryOperator::NotIn, column_expr!("item"), rhs); + let not_in_result = + evaluate_expression(¬_in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let not_in_expected = BooleanArray::from(vec![true, false, true, false]); + assert_eq!(not_in_result.as_ref(), ¬_in_expected); + + let in_expected = BooleanArray::from(vec![false, true, false, true]); + + // Date arrays + let values = Date32Array::from(vec![0, 1, 2, 3]); + let field = Arc::new(Field::new("item", DataType::Date32, true)); + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Date.into(), false), + [Scalar::Date(1), Scalar::Date(3)], + ))); + let schema = Schema::new([field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + + // Timestamp arrays + let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone("UTC"); + let field = Arc::new(Field::new( + "item", + (&DeltaDataTypes::TIMESTAMP).try_into().unwrap(), + true, + )); + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Timestamp.into(), false), + [Scalar::Timestamp(1), Scalar::Timestamp(3)], + ))); + let schema = Schema::new([field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let 
in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + + // Timestamp NTZ arrays + let values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]); + let field = Arc::new(Field::new( + "item", + (&DeltaDataTypes::TIMESTAMP_NTZ).try_into().unwrap(), + true, + )); + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::TimestampNtz.into(), false), + [Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)], + ))); + let schema = Schema::new([field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + } + + #[test] + fn test_column_in_array_with_null() { + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let values = Int32Array::from(vec![Some(1), Some(2), None]); + let schema = Schema::new([field.clone()]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap(); + + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Integer.into(), true), + [Scalar::Integer(1), Scalar::Null(DeltaDataTypes::INTEGER)], + ))); + + // item IN (1, NULL) -- TRUE, NULL, NULL + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![Some(true), None, None]); + assert_eq!(in_result.as_ref(), &in_expected); + + // 1 IN (1, NULL) -- TRUE + let in_op = Expression::binary(BinaryOperator::In, Scalar::Integer(1), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + 
let in_expected = BooleanArray::from(vec![Some(true), Some(true), Some(true)]); + assert_eq!(in_result.as_ref(), &in_expected); + + // 1 NOT IN (1, NULL) -- FALSE + let in_op = Expression::binary(BinaryOperator::NotIn, Scalar::Integer(1), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![Some(false), Some(false), Some(false)]); + assert_eq!(in_result.as_ref(), &in_expected); + + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Integer.into(), true), + [Scalar::Integer(2), Scalar::Null(DeltaDataTypes::INTEGER)], + ))); + + // item IN (2, NULL) -- NULL, TRUE, NULL + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![None, Some(true), None]); + assert_eq!(in_result.as_ref(), &in_expected); + + let in_expected = BooleanArray::from(vec![None, None, None]); + + // 1 IN (2, NULL) -- NULL + let in_op = Expression::binary(BinaryOperator::In, Scalar::Integer(1), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + + // 1 NOT IN (2, NULL) -- NULL + let in_op = Expression::binary(BinaryOperator::NotIn, Scalar::Integer(1), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + + let rhs = Expression::literal(Scalar::Array(ArrayData::new( + ArrayType::new(PrimitiveType::Integer.into(), true), + [Scalar::Integer(1), Scalar::Integer(2)], + ))); + + // item IN (1, 2) -- TRUE, TRUE, NULL + let in_op = Expression::binary(BinaryOperator::In, column_expr!("item"), rhs.clone()); + let in_result = + evaluate_expression(&in_op, &batch, 
Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + let in_expected = BooleanArray::from(vec![Some(true), Some(true), None]); + assert_eq!(in_result.as_ref(), &in_expected); + + let in_expected = BooleanArray::from(vec![None, None, None]); + + // NULL IN (1, 2) -- NULL + let in_op = Expression::binary( + BinaryOperator::In, + Scalar::Null(DeltaDataTypes::INTEGER), + rhs.clone(), + ); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + + // NULL NOT IN (1, 2) -- NULL + let in_op = Expression::binary( + BinaryOperator::NotIn, + Scalar::Null(DeltaDataTypes::INTEGER), + rhs.clone(), + ); + let in_result = + evaluate_expression(&in_op, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); + assert_eq!(in_result.as_ref(), &in_expected); + } + #[test] fn test_extract_column() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -841,29 +1180,25 @@ mod tests { let expression = column_a.clone().and(column_b.clone()); let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)) - .unwrap(); + evaluate_expression(&expression, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); let expected = Arc::new(BooleanArray::from(vec![false, false])); assert_eq!(results.as_ref(), expected.as_ref()); let expression = column_a.clone().and(true); let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)) - .unwrap(); + evaluate_expression(&expression, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); let expected = Arc::new(BooleanArray::from(vec![true, false])); assert_eq!(results.as_ref(), expected.as_ref()); let expression = column_a.clone().or(column_b); let results = - evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN)) - .unwrap(); + evaluate_expression(&expression, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap(); let expected = Arc::new(BooleanArray::from(vec![true, 
true]));
         assert_eq!(results.as_ref(), expected.as_ref());
 
         let expression = column_a.clone().or(false);
         let results =
-            evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN))
-                .unwrap();
+            evaluate_expression(&expression, &batch, Some(&DeltaDataTypes::BOOLEAN)).unwrap();
         let expected = Arc::new(BooleanArray::from(vec![true, false]));
         assert_eq!(results.as_ref(), expected.as_ref());
     }
diff --git a/kernel/src/engine/arrow_utils.rs b/kernel/src/engine/arrow_utils.rs
index 06441b9d4..01ec88606 100644
--- a/kernel/src/engine/arrow_utils.rs
+++ b/kernel/src/engine/arrow_utils.rs
@@ -33,23 +33,33 @@ macro_rules! prim_array_cmp {
             $(
                 $data_ty => {
                     let prim_array = $left_arr.as_primitive_opt::<$prim_ty>()
-                        .ok_or(Error::invalid_expression(
-                            format!("Cannot cast to primitive array: {}", $left_arr.data_type()))
-                        )?;
-                    let list_array = $right_arr.as_list_opt::<i32>()
-                        .ok_or(Error::invalid_expression(
-                            format!("Cannot cast to list array: {}", $right_arr.data_type()))
-                        )?;
-                    arrow_ord::comparison::in_list(prim_array, list_array).map(wrap_comparison_result)
+                        .ok_or(Error::invalid_expression(
+                            format!("Cannot cast to primitive array: {}", $left_arr.data_type()))
+                        )?;
+                    let list_array = $right_arr.as_list_opt::<i32>()
+                        .ok_or(Error::invalid_expression(
+                            format!("Cannot cast to list array: {}", $right_arr.data_type()))
+                        )?;
+                    let in_list_result = arrow_ord::comparison::in_list(prim_array, list_array).map_err(Error::generic_err)?;
+                    Ok(wrap_comparison_result(
+                        in_list_result
+                            .iter()
+                            .zip(list_array.iter())
+                            .map(|(res, arr)| match (res, arr) {
+                                (Some(false), Some(arr)) if arr.null_count() > 0 => None,
+                                _ => res,
+                            })
+                            .collect(),
+                    ))
                 }
             )+
-            _ => Err(ArrowError::CastError(
-                format!("Bad Comparison between: {:?} and {:?}",
-                    $left_arr.data_type(),
-                    $right_arr.data_type())
-                )
-            
-        }.map_err(Error::generic_err);
+            _ => Err(Error::invalid_expression(
+                format!("Bad Comparison between: {:?} and {:?}",
+                    $left_arr.data_type(),
+                    $right_arr.data_type())
+                )
+            )
+        };
     };
 }
diff --git a/kernel/src/expressions/scalars.rs b/kernel/src/expressions/scalars.rs
index 2ce2fd41a..dc467d7eb 100644
--- a/kernel/src/expressions/scalars.rs
+++ b/kernel/src/expressions/scalars.rs
@@ -89,7 +89,7 @@ impl StructData {
 
 /// A single value, which can be null. Used for representing literal values
 /// in [Expressions][crate::expressions::Expression].
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone)]
 pub enum Scalar {
     /// 32bit integer
     Integer(i32),
@@ -224,6 +224,12 @@ impl Display for Scalar {
     }
 }
 
+impl PartialEq for Scalar {
+    fn eq(&self, other: &Scalar) -> bool {
+        self.partial_cmp(other) == Some(Ordering::Equal)
+    }
+}
+
 impl PartialOrd for Scalar {
     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
         use Scalar::*;
@@ -254,10 +260,15 @@ impl PartialOrd for Scalar {
             (Date(_), _) => None,
             (Binary(a), Binary(b)) => a.partial_cmp(b),
             (Binary(_), _) => None,
-            (Decimal(_, _, _), _) => None, // TODO: Support Decimal
-            (Null(_), _) => None, // NOTE: NULL values are incomparable by definition
-            (Struct(_), _) => None, // TODO: Support Struct?
-            (Array(_), _) => None, // TODO: Support Array?
+            (Decimal(v1, _, s1), Decimal(v2, _, s2)) => {
+                let lhs = rust_decimal::Decimal::from_i128_with_scale(*v1, *s1 as u32);
+                let rhs = rust_decimal::Decimal::from_i128_with_scale(*v2, *s2 as u32);
+                lhs.partial_cmp(&rhs)
+            }
+            (Decimal(_, _, _), _) => None,
+            (Null(_), _) => None,   // NOTE: NULL values are incomparable by definition
+            (Struct(_), _) => None, // TODO: Support Struct?
+            (Array(_), _) => None,  // TODO: Support Array?
} } } @@ -585,6 +596,7 @@ mod tests { assert_eq!(&format!("{}", column_op), "3.1415927 IN Column(item)"); assert_eq!(&format!("{}", column_not_op), "'Cool' NOT IN Column(item)"); } + #[test] fn test_timestamp_parse() { let assert_timestamp_eq = |scalar_string, micros| { @@ -599,6 +611,7 @@ mod tests { assert_timestamp_eq("2011-01-11 13:06:07.123456", 1294751167123456); assert_timestamp_eq("1970-01-01 00:00:00", 0); } + #[test] fn test_timestamp_ntz_parse() { let assert_timestamp_eq = |scalar_string, micros| { @@ -627,4 +640,36 @@ mod tests { let p_type = PrimitiveType::Timestamp; assert_timestamp_fails(&p_type, "1971-07-22"); } + + #[test] + fn test_partial_cmp() { + let a = Scalar::Integer(1); + let b = Scalar::Integer(2); + let c = Scalar::Null(DataType::INTEGER); + assert_eq!(a.partial_cmp(&b), Some(Ordering::Less)); + assert_eq!(b.partial_cmp(&a), Some(Ordering::Greater)); + assert_eq!(a.partial_cmp(&a), Some(Ordering::Equal)); + assert_eq!(b.partial_cmp(&b), Some(Ordering::Equal)); + assert_eq!(a.partial_cmp(&c), None); + assert_eq!(c.partial_cmp(&a), None); + + // assert that NULL values are incomparable + let null = Scalar::Null(DataType::INTEGER); + assert_eq!(null.partial_cmp(&null), None); + } + + #[test] + fn test_partial_eq() { + let a = Scalar::Integer(1); + let b = Scalar::Integer(2); + let c = Scalar::Null(DataType::INTEGER); + assert!(!a.eq(&b)); + assert!(a.eq(&a)); + assert!(!a.eq(&c)); + assert!(!c.eq(&a)); + + // assert that NULL values are incomparable + let null = Scalar::Null(DataType::INTEGER); + assert!(!null.eq(&null)); + } } diff --git a/kernel/src/predicates/tests.rs b/kernel/src/predicates/tests.rs index b57d0759a..fdeda8305 100644 --- a/kernel/src/predicates/tests.rs +++ b/kernel/src/predicates/tests.rs @@ -117,7 +117,7 @@ fn test_default_partial_cmp_scalars() { } let expect_if_comparable_type = |s: &_, expect| match s { - Null(_) | Decimal(..) 
| Struct(_) | Array(_) => None, + Null(_) | Struct(_) | Array(_) => None, _ => Some(expect), }; @@ -609,7 +609,7 @@ fn test_sql_where() { expect_eq!(null_filter.eval_sql_where(expr), Some(true), "{expr}"); expect_eq!(empty_filter.eval_sql_where(expr), Some(true), "{expr}"); - // Constrast normal vs SQL WHERE semantics - comparison + // Contrast normal vs SQL WHERE semantics - comparison let expr = &Expr::lt(col.clone(), VAL); expect_eq!(null_filter.eval(expr), None, "{expr}"); expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); @@ -631,7 +631,7 @@ fn test_sql_where() { expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); - // Constrast normal vs SQL WHERE semantics - comparison inside AND + // Contrast normal vs SQL WHERE semantics - comparison inside AND let expr = &Expr::and(TRUE, Expr::lt(col.clone(), VAL)); expect_eq!(null_filter.eval(expr), None, "{expr}"); expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); diff --git a/kernel/src/schema/compare.rs b/kernel/src/schema/compare.rs index e465f1618..eb65540cf 100644 --- a/kernel/src/schema/compare.rs +++ b/kernel/src/schema/compare.rs @@ -56,7 +56,7 @@ pub(crate) type SchemaComparisonResult = Result<(), Error>; /// Represents a schema compatibility check for the type. If `self` can be read as `read_type`, /// this function returns `Ok(())`. Otherwise, this function returns `Err`. /// -/// TODO (Oussama): Remove the `allow(unsued)` once this is used in CDF. +/// TODO (Oussama): Remove the `allow(unused)` once this is used in CDF. #[allow(unused)] pub(crate) trait SchemaComparison { fn can_read_as(&self, read_type: &Self) -> SchemaComparisonResult;