From 1582e8d9cc0d307bdc3311cd22566c66fb6b840f Mon Sep 17 00:00:00 2001 From: Simon Vandel Sillesen Date: Sun, 13 Oct 2024 06:02:36 +0200 Subject: [PATCH 01/19] Optimize `iszero` function (3-5x faster) (#12881) * add bench * Optimize iszero function (3-5x) faster --- datafusion/functions/Cargo.toml | 5 +++ datafusion/functions/benches/iszero.rs | 46 +++++++++++++++++++++++++ datafusion/functions/src/math/iszero.rs | 24 +++++-------- 3 files changed, 60 insertions(+), 15 deletions(-) create mode 100644 datafusion/functions/benches/iszero.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index a3d114221d3f..2ffe93a0e567 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -117,6 +117,11 @@ harness = false name = "make_date" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "iszero" +required-features = ["math_expressions"] + [[bench]] harness = false name = "nullif" diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs new file mode 100644 index 000000000000..3348d172e1f2 --- /dev/null +++ b/datafusion/functions/benches/iszero.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
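+//! Criterion benchmark for the `iszero` math function. It measures `iszero`
+//! over `Float32` and `Float64` arrays of several sizes (1024, 4096, 8192)
+//! so the effect of the `BooleanArray::from_unary` rewrite in this patch can
+//! be quantified.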
+
+extern crate criterion;
+
+use arrow::{
+    datatypes::{Float32Type, Float64Type},
+    util::bench_util::create_primitive_array,
+};
+use criterion::{black_box, criterion_group, criterion_main, Criterion};
+use datafusion_expr::ColumnarValue;
+use datafusion_functions::math::iszero;
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let iszero = iszero();
+    for size in [1024, 4096, 8192] {
+        let f32_array = Arc::new(create_primitive_array::<Float32Type>(size, 0.2));
+        let f32_args = vec![ColumnarValue::Array(f32_array)];
+        c.bench_function(&format!("iszero f32 array: {}", size), |b| {
+            b.iter(|| black_box(iszero.invoke(&f32_args).unwrap()))
+        });
+        let f64_array = Arc::new(create_primitive_array::<Float64Type>(size, 0.2));
+        let f64_args = vec![ColumnarValue::Array(f64_array)];
+        c.bench_function(&format!("iszero f64 array: {}", size), |b| {
+            b.iter(|| black_box(iszero.invoke(&f64_args).unwrap()))
+        });
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs
index e6a728053359..74611b65aaba 100644
--- a/datafusion/functions/src/math/iszero.rs
+++ b/datafusion/functions/src/math/iszero.rs
@@ -18,11 +18,11 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array};
-use arrow::datatypes::DataType;
+use arrow::array::{ArrayRef, AsArray, BooleanArray};
 use arrow::datatypes::DataType::{Boolean, Float32, Float64};
+use arrow::datatypes::{DataType, Float32Type, Float64Type};
 
-use datafusion_common::{exec_err, DataFusionError, Result};
+use datafusion_common::{exec_err, Result};
 use datafusion_expr::ColumnarValue;
 use datafusion_expr::TypeSignature::Exact;
 use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
@@ -77,20 +77,14 @@ impl ScalarUDFImpl for IsZeroFunc {
 
 /// Iszero SQL function
 pub fn iszero(args: &[ArrayRef]) -> Result<ArrayRef> {
     match args[0].data_type() {
-        Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
-            &args[0],
-            "x",
-            Float64Array,
-            BooleanArray,
-            { |x: f64| { x == 0_f64 } }
+        Float64 => Ok(Arc::new(BooleanArray::from_unary(
+            args[0].as_primitive::<Float64Type>(),
+            |x| x == 0.0,
         )) as ArrayRef),
-        Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
-            &args[0],
-            "x",
-            Float32Array,
-            BooleanArray,
-            { |x: f32| { x == 0_f32 } }
+        Float32 => Ok(Arc::new(BooleanArray::from_unary(
+            args[0].as_primitive::<Float32Type>(),
+            |x| x == 0.0,
         )) as ArrayRef),
 
         other => exec_err!("Unsupported data type {other:?} for function iszero"),

From ebfc15506c34533a33602931463ad6d74c803551 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 13 Oct 2024 07:21:50 -0400
Subject: [PATCH 02/19] Macro for creating record batch from literal slice (#12846)

* Add macro for creating record batch, useful for unit test or rapid development

* Update docstring

* Add additional checks in unit test and rename macro per user input

---
 datafusion/common/src/test_util.rs | 120 +++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)

diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs
index 36254192550c..422fcb5eb3e0 100644
--- a/datafusion/common/src/test_util.rs
+++ b/datafusion/common/src/test_util.rs
@@ -279,8 +279,88 @@ pub fn get_data_dir(
     }
 }
 
+#[macro_export]
+macro_rules! 
create_array { + (Boolean, $values: expr) => { + std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + }; + (Int8, $values: expr) => { + std::sync::Arc::new(arrow::array::Int8Array::from($values)) + }; + (Int16, $values: expr) => { + std::sync::Arc::new(arrow::array::Int16Array::from($values)) + }; + (Int32, $values: expr) => { + std::sync::Arc::new(arrow::array::Int32Array::from($values)) + }; + (Int64, $values: expr) => { + std::sync::Arc::new(arrow::array::Int64Array::from($values)) + }; + (UInt8, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + }; + (UInt16, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + }; + (UInt32, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + }; + (UInt64, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + }; + (Float16, $values: expr) => { + std::sync::Arc::new(arrow::array::Float16Array::from($values)) + }; + (Float32, $values: expr) => { + std::sync::Arc::new(arrow::array::Float32Array::from($values)) + }; + (Float64, $values: expr) => { + std::sync::Arc::new(arrow::array::Float64Array::from($values)) + }; + (Utf8, $values: expr) => { + std::sync::Arc::new(arrow::array::StringArray::from($values)) + }; +} + +/// Creates a record batch from literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// ``` +/// use datafusion_common::{record_batch, create_array}; +/// let batch = record_batch!( +/// ("a", Int32, vec![1, 2, 3]), +/// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, vec!["alpha", "beta", "gamma"]) +/// ); +/// ``` +#[macro_export] +macro_rules! record_batch { + ($(($name: expr, $type: ident, $values: expr)),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = arrow_array::RecordBatch::try_new( + schema, + vec![$( + create_array!($type, $values), + )*] + ); + + batch + } + } +} + #[cfg(test)] mod tests { + use crate::cast::{as_float64_array, as_int32_array, as_string_array}; + use crate::error::Result; + use super::*; use std::env; @@ -333,4 +413,44 @@ mod tests { let res = parquet_test_data(); assert!(PathBuf::from(res).is_dir()); } + + #[test] + fn test_create_record_batch() -> Result<()> { + use arrow_array::Array; + + let batch = record_batch!( + ("a", Int32, vec![1, 2, 3, 4]), + ("b", Float64, vec![Some(4.0), None, Some(5.0), None]), + ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"]) + )?; + + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let values: Vec<_> = as_int32_array(batch.column(0))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![1, 2, 3, 4]); + + let values: Vec<_> = as_float64_array(batch.column(1))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]); + + let nulls: Vec<_> = as_float64_array(batch.column(1))? 
+ .nulls() + .unwrap() + .iter() + .collect(); + assert_eq!(nulls, vec![true, false, true, false]); + + let values: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect(); + assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]); + + Ok(()) + } } From 646f40a44330cdcfad5fc779897046d1dc0b83c5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 13 Oct 2024 08:07:12 -0400 Subject: [PATCH 03/19] Implement special min/max accumulator for Strings and Binary (10% faster for Clickbench Q28) (#12792) * Implement special min/max accumulator for Strings: `MinMaxBytesAccumulator` * fix bug * fix msrv * move code, handle filters * simplify * Add functional tests * remove unecessary test * improve docs * improve docs * cleanup * improve comments * fix diagram * fix accounting * Use correct type in memory accounting * Add TODO comment --- .../groups_accumulator/accumulate.rs | 2 +- .../src/aggregate/groups_accumulator/nulls.rs | 115 +++- datafusion/functions-aggregate/src/min_max.rs | 123 +++-- .../src/min_max/min_max_bytes.rs | 515 ++++++++++++++++++ .../sqllogictest/test_files/aggregate.slt | 174 ++++++ 5 files changed, 872 insertions(+), 57 deletions(-) create mode 100644 datafusion/functions-aggregate/src/min_max/min_max_bytes.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index a0475fe8e446..3efd348937ed 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -95,7 +95,7 @@ impl NullState { /// /// When value_fn is called it also sets /// - /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value pub fn accumulate( &mut self, group_indices: &[usize], diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 25212f7f0f5f..6a8946034cbc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -15,13 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! [`set_nulls`], and [`filtered_null_mask`], utilities for working with nulls +//! 
[`set_nulls`], other utilities for working with nulls -use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::array::{ + Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, + BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, +}; use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use std::sync::Arc; /// Sets the validity mask for a `PrimitiveArray` to `nulls` /// replacing any existing null mask +/// +/// See [`set_nulls_dyn`] for a version that works with `Array` pub fn set_nulls( array: PrimitiveArray, nulls: Option, @@ -91,3 +100,105 @@ pub fn filtered_null_mask( let opt_filter = opt_filter.and_then(filter_to_nulls); NullBuffer::union(opt_filter.as_ref(), input.nulls()) } + +/// Applies optional filter to input, returning a new array of the same type +/// with the same data, but with any values that were filtered out set to null +pub fn apply_filter_as_nulls( + input: &dyn Array, + opt_filter: Option<&BooleanArray>, +) -> Result { + let nulls = filtered_null_mask(opt_filter, input); + set_nulls_dyn(input, nulls) +} + +/// Replaces the nulls in the input array with the given `NullBuffer` +/// +/// TODO: replace when upstreamed in arrow-rs: +pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), input.len()); + } + + let output: ArrayRef = match input.data_type() { + DataType::Utf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeUtf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeStringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::Utf8View => { + let input = input.as_string_view(); + // safety: values / views came from a valid string view array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + + DataType::Binary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeBinary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid large binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeBinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::BinaryView => { + let input = input.as_binary_view(); + // safety: values / views came from a valid binary view array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + _ => { + return not_impl_err!("Applying nulls {:?}", input.data_type()); + } + }; + 
assert_eq!(input.len(), output.len()); + assert_eq!(input.data_type(), output.data_type()); + + Ok(output) +} diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 3d2915fd09cb..2f7954a8ee02 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -17,6 +17,8 @@ //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function //! [`Min`] and [`MinAccumulator`] accumulator for the `min` function +mod min_max_bytes; + use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, @@ -50,6 +52,7 @@ use arrow::datatypes::{ TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; +use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; use datafusion_common::ScalarValue; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, @@ -104,7 +107,7 @@ impl Default for Max { /// the specified [`ArrowPrimitiveType`]. /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_max_accumulator { +macro_rules! primitive_max_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { @@ -123,7 +126,7 @@ macro_rules! instantiate_max_accumulator { /// /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_min_accumulator { +macro_rules! primitive_min_accumulator { ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ Ok(Box::new( PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { @@ -231,6 +234,12 @@ impl AggregateUDFImpl for Max { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -242,58 +251,58 @@ impl AggregateUDFImpl for Max { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_max_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_max_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_max_accumulator!(data_type, f16, Float16Type) + primitive_max_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_max_accumulator!(data_type, f32, Float32Type) + primitive_max_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_max_accumulator!(data_type, f64, Float64Type) + primitive_max_accumulator!(data_type, f64, 
Float64Type) } - Date32 => instantiate_max_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_max_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_max_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_max_accumulator!(data_type, i32, Time32SecondType) + primitive_max_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_max_accumulator!(data_type, i32, Time32MillisecondType) + primitive_max_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_max_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_max_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_max_accumulator!(data_type, i64, Time64NanosecondType) + primitive_max_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampSecondType) + primitive_max_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_max_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_max_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_max_accumulator!(data_type, i128, Decimal128Type) + primitive_max_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_max_accumulator!(data_type, i256, Decimal256Type) + primitive_max_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), @@ -1057,6 +1066,12 @@ impl AggregateUDFImpl for Min { | Time32(_) | Time64(_) | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView ) } @@ -1068,58 +1083,58 @@ impl AggregateUDFImpl for Min { use TimeUnit::*; let data_type = args.return_type; match data_type { - Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), - Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), - Int32 => instantiate_min_accumulator!(data_type, i32, Int32Type), - Int64 => instantiate_min_accumulator!(data_type, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(data_type, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => 
primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type), Float16 => { - instantiate_min_accumulator!(data_type, f16, Float16Type) + primitive_min_accumulator!(data_type, f16, Float16Type) } Float32 => { - instantiate_min_accumulator!(data_type, f32, Float32Type) + primitive_min_accumulator!(data_type, f32, Float32Type) } Float64 => { - instantiate_min_accumulator!(data_type, f64, Float64Type) + primitive_min_accumulator!(data_type, f64, Float64Type) } - Date32 => instantiate_min_accumulator!(data_type, i32, Date32Type), - Date64 => instantiate_min_accumulator!(data_type, i64, Date64Type), + Date32 => primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_min_accumulator!(data_type, i64, Date64Type), Time32(Second) => { - instantiate_min_accumulator!(data_type, i32, Time32SecondType) + primitive_min_accumulator!(data_type, i32, Time32SecondType) } Time32(Millisecond) => { - instantiate_min_accumulator!(data_type, i32, Time32MillisecondType) + primitive_min_accumulator!(data_type, i32, Time32MillisecondType) } Time64(Microsecond) => { - instantiate_min_accumulator!(data_type, i64, Time64MicrosecondType) + primitive_min_accumulator!(data_type, i64, Time64MicrosecondType) } Time64(Nanosecond) => { - instantiate_min_accumulator!(data_type, i64, Time64NanosecondType) + primitive_min_accumulator!(data_type, i64, Time64NanosecondType) } Timestamp(Second, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampSecondType) + primitive_min_accumulator!(data_type, i64, TimestampSecondType) } Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMillisecondType) + primitive_min_accumulator!(data_type, i64, TimestampMillisecondType) } Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampMicrosecondType) + primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType) } Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(data_type, i64, TimestampNanosecondType) + primitive_min_accumulator!(data_type, i64, TimestampNanosecondType) } Decimal128(_, _) => { - instantiate_min_accumulator!(data_type, i128, Decimal128Type) + primitive_min_accumulator!(data_type, i128, Decimal128Type) } Decimal256(_, _) => { - instantiate_min_accumulator!(data_type, i256, Decimal256Type) + primitive_min_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 // This is only reached if groups_accumulator_supported is out of sync _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs new file mode 100644 index 000000000000..e3f01b91bf3e --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -0,0 +1,515 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, +}; +use arrow_schema::DataType; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use std::sync::Arc; + +/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], +/// [`BinaryArray`], [`StringViewArray`], etc) +/// +/// This implementation dispatches to the appropriate specialized code in +/// [`MinMaxBytesState`] based on data type and comparison function +/// +/// [`StringArray`]: arrow::array::StringArray +/// [`BinaryArray`]: arrow::array::BinaryArray +/// [`StringViewArray`]: arrow::array::StringViewArray +#[derive(Debug)] +pub(crate) struct MinMaxBytesAccumulator { + /// Inner data storage. + inner: MinMaxBytesState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxBytesAccumulator { + /// Create a new accumulator for computing `min(val)` + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: true, + } + } + + /// Create a new accumulator fo computing `max(val)` + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxBytesAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + // dispatch to appropriate kernel / specialized implementation + fn string_min(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a < b + } + } + fn string_max(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a > b + } + } + fn binary_min(a: &[u8], b: &[u8]) -> bool { + a < b + } + + fn binary_max(a: &[u8], b: &[u8]) -> bool { + a > b + } + + fn str_to_bytes<'a>( + it: impl Iterator>, + ) -> impl Iterator> { + it.map(|s| s.map(|s| s.as_bytes())) + } + + match (self.is_min, &self.inner.data_type) { + // Utf8/LargeUtf8/Utf8View Min + (true, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + 
(true, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_min, + ), + + // Utf8/LargeUtf8/Utf8View Max + (false, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_max, + ), + + // Binary/LargeBinary/BinaryView Min + (true, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_min, + ), + + // Binary/LargeBinary/BinaryView Max + (false, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_max, + ), + + _ => internal_err!( + "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})", + self.is_min, + self.inner.data_type + ), + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); + + // Convert the Vec of bytes to a vec of Strings (at no cost) + fn bytes_to_str( + min_maxes: Vec>>, + ) -> impl Iterator> { + min_maxes.into_iter().map(|opt| { + opt.map(|bytes| { + // Safety: only called on data added from update_batch which ensures + // the input type matched the output type + unsafe { String::from_utf8_unchecked(bytes) } + }) + }) + } + + let result: ArrayRef = match self.inner.data_type { + DataType::Utf8 => { + let mut builder = + StringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Utf8View => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = StringViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Binary => { + let mut builder = + BinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + 
DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::BinaryView => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = BinaryViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + _ => { + return internal_err!( + "Unexpected data type for MinMaxBytesAccumulator: {:?}", + self.inner.data_type + ); + } + }; + + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(result) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// Returns the block size in (contiguous buffer size) to use +/// for a given data capacity (total string length) +/// +/// This is a heuristic to avoid allocating too many small buffers +fn capacity_to_view_block_size(data_capacity: usize) -> u32 { + let max_block_size = 2 * 1024 * 1024; + if let Ok(block_size) = u32::try_from(data_capacity) { + block_size.min(max_block_size) + } else { + max_block_size + } +} + +/// Stores internal Min/Max state for "bytes" types. +/// +/// This implementation is general and stores the minimum/maximum for each +/// groups in an individual byte array, which balances allocations and memory +/// fragmentation (aka garbage). +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌────▶│Option> (["A"]) │───────────▶ "A" +/// │ 0 │────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │─────────▶│Option> (["Z"]) │───────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │─────────▶│Option> (["A"]) │────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │────┐ ┌────────────────────────────────┐ +/// └─────┘ └────▶│Option> (["Q"]) │────────────▶ "Q" +/// └────────────────────────────────┘ +/// +/// min_max: Vec> +/// ``` +/// +/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially +/// more efficient implementations (e.g. by managing a string data buffer +/// directly), but then garbage collection, memory management, and final array +/// construction becomes more complex. 
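+///
+/// Memory accounting: `total_data_bytes` tracks the combined length of all
+/// stored values, so `size()` can report memory use and `emit_to` can
+/// pre-size the output array without re-scanning every stored `Vec<u8>`.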
+/// +/// See discussion on +#[derive(Debug)] +struct MinMaxBytesState { + /// The minimum/maximum value for each group + min_max: Vec>>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone, Copy)] +enum MinMaxLocation<'a> { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the input array at the given index + Input(&'a [u8]), +} + +/// Implement the MinMaxBytesAccumulator with a comparison function +/// for comparing strings +impl MinMaxBytesState { + /// Create a new MinMaxBytesAccumulator + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &[u8]) { + match self.min_max[group_index].as_mut() { + None => { + self.min_max[group_index] = Some(new_val.to_vec()); + self.total_data_bytes += new_val.len(); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.len(); + self.total_data_bytes += new_val.len(); + existing_val.clear(); + existing_val.extend_from_slice(new_val); + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<'a, F, I>( + &mut self, + iter: I, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&[u8], &[u8]) -> bool + Send + Sync, + I: IntoIterator>, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owne values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { + let group_index = *group_index; + let Some(new_val) = new_val else { + continue; // skip nulls + }; + + let existing_val = match locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(exising_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + exising_val.as_ref() + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of 
total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + + self.min_max.len() * std::mem::size_of::>>() + } +} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index ce382a9bf8d2..f03c3700ab9f 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3818,6 +3818,180 @@ DROP TABLE min_bool; # Min_Max End # ################# + + +################# +# min_max on strings/binary with null values and groups +################# + +statement ok +CREATE TABLE strings (value TEXT, id int); + +statement ok +INSERT INTO strings VALUES + ('c', 1), + ('d', 1), + ('a', 3), + ('c', 1), + ('b', 1), + (NULL, 1), + (NULL, 4), + ('d', 1), + ('z', 2), + ('c', 1), + ('a', 2); + +############ Utf8 ############ + +query IT +SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +############ LargeUtf8 ############ + +statement ok +CREATE VIEW large_strings AS SELECT id, arrow_cast(value, 'LargeUtf8') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW large_strings + +############ Utf8View ############ + +statement ok +CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW string_views + +############ Binary ############ + +statement ok +CREATE VIEW binary AS SELECT id, arrow_cast(value, 'Binary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary + +############ LargeBinary ############ + +statement ok +CREATE VIEW large_binary AS SELECT id, arrow_cast(value, 'LargeBinary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW large_binary + +############ BinaryView ############ + +statement ok +CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; + + +query I? 
+SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id;
+----
+1 62
+2 61
+3 61
+4 NULL
+
+query I?
+SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id;
+----
+1 64
+2 7a
+3 61
+4 NULL
+
+statement ok
+DROP VIEW binary_views
+
+statement ok
+DROP TABLE strings;
+
+#################
+# End min_max on strings/binary with null values and groups
+#################
+
+
 statement ok
 create table bool_aggregate_functions (
   c1 boolean not null,

From 1b10c9f89eac127507fe7a137ff7c40534f7ca9a Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sun, 13 Oct 2024 07:26:42 -0500
Subject: [PATCH 04/19] Make PruningPredicate's rewrite public (#12850)

* Make PruningPredicate's rewrite public

* feedback

* Improve documentation and add default to ConstantUnhandledPredicateHook

* Update pruning.rs

---------

Co-authored-by: Andrew Lamb

---
 .../core/src/physical_optimizer/pruning.rs    | 212 ++++++++++++++++--
 1 file changed, 188 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 9bc2bb1d1db9..eb03b337779c 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -458,7 +458,7 @@ pub trait PruningStatistics {
 /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
 /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
-///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
+/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
 #[derive(Debug, Clone)]
 pub struct PruningPredicate {
     /// The input schema against which the predicate will be evaluated
@@ -478,6 +478,36 @@ pub struct PruningPredicate {
     literal_guarantees: Vec<LiteralGuarantee>,
 }
 
+/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
+/// complex expressions or predicates that reference columns that are not in the
+/// schema.
+pub trait UnhandledPredicateHook {
+    /// Called when a predicate can not be rewritten in terms of statistics or
+    /// references a column that is not in the schema.
+    fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
+}
+
+/// The default handling for unhandled predicates is to return a constant `true`
+/// (meaning don't prune the container)
+#[derive(Debug, Clone)]
+struct ConstantUnhandledPredicateHook {
+    default: Arc<dyn PhysicalExpr>,
+}
+
+impl Default for ConstantUnhandledPredicateHook {
+    fn default() -> Self {
+        Self {
+            default: Arc::new(phys_expr::Literal::new(ScalarValue::from(true))),
+        }
+    }
+}
+
+impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
+    fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
+        self.default.clone()
+    }
+}
+
 impl PruningPredicate {
     /// Try to create a new instance of [`PruningPredicate`]
     ///
@@ -502,10 +532,16 @@ impl PruningPredicate {
     /// See the struct level documentation on [`PruningPredicate`] for more
     /// details.
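    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `expr` is an already-built
    /// `Arc<dyn PhysicalExpr>` (for example `x > 5`) and `schema` is the
    /// `SchemaRef` it was built against:
    ///
    /// ```ignore
    /// let predicate = PruningPredicate::try_new(expr, Arc::clone(&schema))?;
    /// // the rewritten predicate can now be evaluated against per-container
    /// // statistics (e.g. via `PruningPredicate::prune`) to skip containers
    /// ```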
pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -1312,27 +1348,78 @@ fn build_is_null_column_expr( /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; +/// Rewrite a predicate expression in terms of statistics (min/max/null_counts) +/// for use as a [`PruningPredicate`]. +pub struct PredicateRewriter { + unhandled_hook: Arc, +} + +impl Default for PredicateRewriter { + fn default() -> Self { + Self { + unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()), + } + } +} + +impl PredicateRewriter { + /// Create a new `PredicateRewriter` + pub fn new() -> Self { + Self::default() + } + + /// Set the unhandled hook to be used when a predicate can not be rewritten + pub fn with_unhandled_hook( + self, + unhandled_hook: Arc, + ) -> Self { + Self { unhandled_hook } + } + + /// Translate logical filter expression into pruning predicate + /// expression that will evaluate to FALSE if it can be determined no + /// rows between the min/max values could pass the predicates. + /// + /// Any predicates that can not be translated will be passed to `unhandled_hook`. + /// + /// Returns the pruning predicate as an [`PhysicalExpr`] + /// + /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` + pub fn rewrite_predicate_to_statistics_predicate( + &self, + expr: &Arc, + schema: &Schema, + ) -> Arc { + let mut required_columns = RequiredColumns::new(); + build_predicate_expression( + expr, + schema, + &mut required_columns, + &self.unhandled_hook, + ) + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// +/// Any predicates that can not be translated will be passed to `unhandled_hook`. +/// /// Returns the pruning predicate as an [`PhysicalExpr`] /// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. 
- let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1341,19 +1428,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1382,9 +1469,14 @@ fn build_predicate_expression( }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } @@ -1396,13 +1488,15 @@ fn build_predicate_expression( bin_expr.right().clone(), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1410,7 +1504,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1423,12 +1517,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), }; - build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) + build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1582,6 +1675,8 @@ mod tests { use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; + use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] @@ -3397,6 +3492,74 @@ mod tests { // TODO: 
add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42)))) + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema_with_b = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let rewriter = PredicateRewriter::new() + .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); + + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) + }; + + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(12)); + let known_expression_transformed = PredicateRewriter::new() + .rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + ); + + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(12)); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3886,6 +4049,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } } From 636f43321acfd295096ad3ec45ef00595203f3f7 Mon Sep 17 00:00:00 2001 From: Haile <52631736+hailelagi@users.noreply.github.com> Date: Mon, 14 Oct 2024 03:21:51 +0100 Subject: [PATCH 05/19] Minor: add flags for temporary ddl (#12561) * Minor: add flags for temporary ddl Signed-off-by: Haile Lagi <52631736+hailelagi@users.noreply.github.com> * Update datafusion/proto/src/logical_plan/mod.rs Co-authored-by: Jonah Gao --------- Signed-off-by: Haile Lagi 
<52631736+hailelagi@users.noreply.github.com> Co-authored-by: Jonah Gao --- .../core/src/catalog_common/listing_schema.rs | 1 + .../src/datasource/listing_table_factory.rs | 2 ++ datafusion/core/src/execution/context/mod.rs | 16 +++++++++ datafusion/expr/src/logical_plan/ddl.rs | 6 ++++ datafusion/expr/src/logical_plan/plan.rs | 4 +++ datafusion/expr/src/logical_plan/tree_node.rs | 4 +++ datafusion/proto/proto/datafusion.proto | 2 ++ datafusion/proto/src/generated/pbjson.rs | 34 +++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 4 +++ datafusion/proto/src/logical_plan/mod.rs | 6 ++++ datafusion/sql/src/parser.rs | 24 +++++++++++++ datafusion/sql/src/query.rs | 1 + datafusion/sql/src/statement.rs | 8 +++-- .../test_files/create_external_table.slt | 3 -- datafusion/sqllogictest/test_files/ddl.slt | 24 +++++++++++++ 15 files changed, 133 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/catalog_common/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs index e45c8a8d4aeb..665ea58c5f75 100644 --- a/datafusion/core/src/catalog_common/listing_schema.rs +++ b/datafusion/core/src/catalog_common/listing_schema.rs @@ -136,6 +136,7 @@ impl ListingSchemaProvider { file_type: self.format.clone(), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index fed63ec12b49..701a13477b5b 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -197,6 +197,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, @@ -236,6 +237,7 @@ mod tests { schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b0951d9ec44c..606759aae5ee 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -738,6 +738,11 @@ impl SessionContext { cmd: &CreateExternalTable, ) -> Result { let exist = self.table_exist(cmd.name.clone())?; + + if cmd.temporary { + return not_impl_err!("Temporary tables not supported"); + } + if exist { match cmd.if_not_exists { true => return self.return_empty_dataframe(), @@ -761,10 +766,16 @@ impl SessionContext { or_replace, constraints, column_defaults, + temporary, } = cmd; let input = Arc::unwrap_or_clone(input); let input = self.state().optimize(&input)?; + + if temporary { + return not_impl_err!("Temporary tables not supported"); + } + let table = self.table(name.clone()).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -813,10 +824,15 @@ impl SessionContext { input, or_replace, definition, + temporary, } = cmd; let view = self.table(name.clone()).await; + if temporary { + return not_impl_err!("Temporary views not supported"); + } + match (or_replace, view) { (true, Ok(_)) => { self.deregister_table(name.clone())?; diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 9aaa5c98037a..c4fa9f4c3fed 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ 
b/datafusion/expr/src/logical_plan/ddl.rs @@ -202,6 +202,8 @@ pub struct CreateExternalTable { pub table_partition_cols: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user @@ -298,6 +300,8 @@ pub struct CreateMemoryTable { pub or_replace: bool, /// Default values for columns pub column_defaults: Vec<(String, Expr)>, + /// Wheter the table is `TableType::Temporary` + pub temporary: bool, } /// Creates a view. @@ -311,6 +315,8 @@ pub struct CreateView { pub or_replace: bool, /// SQL used to create the view, if available pub definition: Option, + /// Wheter the view is ephemeral + pub temporary: bool, } /// Creates a catalog (aka "Database"). diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 0292274e57ee..9bd57d22128d 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -993,6 +993,7 @@ impl LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1005,6 +1006,7 @@ impl LogicalPlan { if_not_exists: *if_not_exists, or_replace: *or_replace, column_defaults: column_defaults.clone(), + temporary: *temporary, }, ))) } @@ -1012,6 +1014,7 @@ impl LogicalPlan { name, or_replace, definition, + temporary, .. })) => { self.assert_no_expressions(expr)?; @@ -1020,6 +1023,7 @@ impl LogicalPlan { input: Arc::new(input), name: name.clone(), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone(), }))) } diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 8ba68697bd4d..83206a2b2af5 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -285,6 +285,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -293,6 +294,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) }), DdlStatement::CreateView(CreateView { @@ -300,12 +302,14 @@ impl TreeNode for LogicalPlan { input, or_replace, definition, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, or_replace, definition, + temporary, }) }), // no inputs in these statements diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 803cb49919ee..5256f7473c95 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -163,6 +163,7 @@ message CreateExternalTableNode { datafusion_common.DfSchema schema = 4; repeated string table_partition_cols = 5; bool if_not_exists = 6; + bool temporary = 14; string definition = 7; repeated SortExprNodeCollection order_exprs = 10; bool unbounded = 11; @@ -200,6 +201,7 @@ message CreateViewNode { TableReference name = 5; LogicalPlanNode input = 2; bool or_replace = 3; + bool temporary = 6; string definition = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index c7d4c4561a1b..e876008e853f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -3193,6 +3193,9 @@ impl serde::Serialize for 
CreateExternalTableNode { if self.if_not_exists { len += 1; } + if self.temporary { + len += 1; + } if !self.definition.is_empty() { len += 1; } @@ -3230,6 +3233,9 @@ impl serde::Serialize for CreateExternalTableNode { if self.if_not_exists { struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } @@ -3267,6 +3273,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "tablePartitionCols", "if_not_exists", "ifNotExists", + "temporary", "definition", "order_exprs", "orderExprs", @@ -3285,6 +3292,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { Schema, TablePartitionCols, IfNotExists, + Temporary, Definition, OrderExprs, Unbounded, @@ -3318,6 +3326,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { "schema" => Ok(GeneratedField::Schema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "temporary" => Ok(GeneratedField::Temporary), "definition" => Ok(GeneratedField::Definition), "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), "unbounded" => Ok(GeneratedField::Unbounded), @@ -3349,6 +3358,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { let mut schema__ = None; let mut table_partition_cols__ = None; let mut if_not_exists__ = None; + let mut temporary__ = None; let mut definition__ = None; let mut order_exprs__ = None; let mut unbounded__ = None; @@ -3393,6 +3403,12 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } if_not_exists__ = Some(map_.next_value()?); } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); @@ -3442,6 +3458,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { schema: schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), if_not_exists: if_not_exists__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), definition: definition__.unwrap_or_default(), order_exprs: order_exprs__.unwrap_or_default(), unbounded: unbounded__.unwrap_or_default(), @@ -3471,6 +3488,9 @@ impl serde::Serialize for CreateViewNode { if self.or_replace { len += 1; } + if self.temporary { + len += 1; + } if !self.definition.is_empty() { len += 1; } @@ -3484,6 +3504,9 @@ impl serde::Serialize for CreateViewNode { if self.or_replace { struct_ser.serialize_field("orReplace", &self.or_replace)?; } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } @@ -3501,6 +3524,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { "input", "or_replace", "orReplace", + "temporary", "definition", ]; @@ -3509,6 +3533,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { Name, Input, OrReplace, + Temporary, Definition, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -3534,6 +3559,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { "name" => Ok(GeneratedField::Name), "input" => Ok(GeneratedField::Input), "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), + 
"temporary" => Ok(GeneratedField::Temporary), "definition" => Ok(GeneratedField::Definition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -3557,6 +3583,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { let mut name__ = None; let mut input__ = None; let mut or_replace__ = None; + let mut temporary__ = None; let mut definition__ = None; while let Some(k) = map_.next_key()? { match k { @@ -3578,6 +3605,12 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { } or_replace__ = Some(map_.next_value()?); } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } GeneratedField::Definition => { if definition__.is_some() { return Err(serde::de::Error::duplicate_field("definition")); @@ -3590,6 +3623,7 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { name: name__, input: input__, or_replace: or_replace__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), definition: definition__.unwrap_or_default(), }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8cba6f84f7eb..2aa14f7e80b0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -240,6 +240,8 @@ pub struct CreateExternalTableNode { pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(bool, tag = "6")] pub if_not_exists: bool, + #[prost(bool, tag = "14")] + pub temporary: bool, #[prost(string, tag = "7")] pub definition: ::prost::alloc::string::String, #[prost(message, repeated, tag = "10")] @@ -303,6 +305,8 @@ pub struct CreateViewNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(bool, tag = "3")] pub or_replace: bool, + #[prost(bool, tag = "6")] + pub temporary: bool, #[prost(string, tag = "4")] pub definition: ::prost::alloc::string::String, } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 7156cee66aff..6061a7a0619a 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -579,6 +579,7 @@ impl AsLogicalPlan for LogicalPlanNode { .clone(), order_exprs, if_not_exists: create_extern_table.if_not_exists, + temporary: create_extern_table.temporary, definition, unbounded: create_extern_table.unbounded, options: create_extern_table.options.clone(), @@ -601,6 +602,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { name: from_table_reference(create_view.name.as_ref(), "CreateView")?, + temporary: create_view.temporary, input: Arc::new(plan), or_replace: create_view.or_replace, definition, @@ -1386,6 +1388,7 @@ impl AsLogicalPlan for LogicalPlanNode { options, constraints, column_defaults, + temporary, }, )) => { let mut converted_order_exprs: Vec = vec![]; @@ -1412,6 +1415,7 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Some(df_schema.try_into()?), table_partition_cols: table_partition_cols.clone(), if_not_exists: *if_not_exists, + temporary: *temporary, order_exprs: converted_order_exprs, definition: definition.clone().unwrap_or_default(), unbounded: *unbounded, @@ -1427,6 +1431,7 @@ impl AsLogicalPlan for LogicalPlanNode { input, or_replace, definition, + temporary, })) => Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { @@ -1436,6 +1441,7 @@ impl AsLogicalPlan for LogicalPlanNode { 
extension_codec, )?)), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone().unwrap_or_default(), }, ))), diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index a68d8491856d..8a984f1645e9 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -194,6 +194,8 @@ pub struct CreateExternalTable { pub order_exprs: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// Infinite streams? pub unbounded: bool, /// Table(provider) specific options @@ -699,6 +701,10 @@ impl<'a> DFParser<'a> { &mut self, unbounded: bool, ) -> Result { + let temporary = self + .parser + .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) + .is_some(); self.parser.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parser @@ -820,6 +826,7 @@ impl<'a> DFParser<'a> { table_partition_cols: builder.table_partition_cols.unwrap_or(vec![]), order_exprs: builder.order_exprs, if_not_exists, + temporary, unbounded, options: builder.options.unwrap_or(Vec::new()), constraints, @@ -924,6 +931,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -940,6 +948,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -957,6 +966,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -974,6 +984,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![( "format.delimiter".into(), @@ -994,6 +1005,7 @@ mod tests { table_partition_cols: vec!["p1".to_string(), "p2".to_string()], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1021,6 +1033,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![( "format.compression".into(), @@ -1041,6 +1054,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1057,6 +1071,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1073,6 +1088,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1090,6 +1106,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: true, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1110,6 +1127,7 @@ mod tests { table_partition_cols: vec!["p1".to_string()], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1140,6 +1158,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: false, unbounded: false, options: vec![("k1".into(), Value::SingleQuotedString("v1".into()))], constraints: vec![], @@ -1157,6 +1176,7 @@ mod tests { table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, + temporary: 
false, unbounded: false, options: vec![ ("k1".into(), Value::SingleQuotedString("v1".into())), @@ -1204,6 +1224,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1244,6 +1265,7 @@ mod tests { }, ]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1279,6 +1301,7 @@ mod tests { with_fill: None, }]], if_not_exists: false, + temporary: false, unbounded: false, options: vec![], constraints: vec![], @@ -1323,6 +1346,7 @@ mod tests { with_fill: None, }]], if_not_exists: true, + temporary: false, unbounded: true, options: vec![ ( diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 125259d2276f..54945ec43d10 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -150,6 +150,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), if_not_exists: false, or_replace: false, + temporary: false, column_defaults: vec![], }, ))), diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 3111fab9a2ff..edb4316db1e0 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -440,6 +440,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } @@ -463,6 +464,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } @@ -498,9 +500,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if if_not_exists { return not_impl_err!("If not exists not supported")?; } - if temporary { - return not_impl_err!("Temporary views not supported")?; - } if to.is_some() { return not_impl_err!("To not supported")?; } @@ -526,6 +525,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), or_replace, definition: sql, + temporary, }))) } Statement::ShowCreate { obj_type, obj_name } => match obj_type { @@ -1198,6 +1198,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { location, table_partition_cols, if_not_exists, + temporary, order_exprs, unbounded, options, @@ -1250,6 +1251,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { file_type, table_partition_cols, if_not_exists, + temporary, definition, order_exprs: ordered_exprs, unbounded, diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 9ac2ecdce7cc..7dba4d01d63b 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -102,9 +102,6 @@ CREATE TEMPORARY TABLE my_temp_table ( name TEXT NOT NULL ); -statement error DataFusion error: This feature is not implemented: Temporary views not supported -CREATE TEMPORARY VIEW my_temp_view AS SELECT id, name FROM my_table; - # Partitioned table on a single file query error DataFusion error: Error during planning: Can't create a partitioned table backed by a single file, perhaps the URL is missing a trailing slash\? 
CREATE EXTERNAL TABLE single_file_partition(c1 int) diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 21edb458fe56..813f7e95adf0 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -775,3 +775,27 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te statement ok drop table t; + +statement ok +set datafusion.explain.logical_plan_only=true; + +query TT +explain CREATE TEMPORARY VIEW z AS VALUES (1,2,3); +---- +logical_plan +01)CreateView: Bare { table: "z" } +02)--Values: (Int64(1), Int64(2), Int64(3)) + +query TT +explain CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; +---- +logical_plan CreateExternalTable: Bare { table: "tty" } + +statement ok +set datafusion.explain.logical_plan_only=false; + +statement error DataFusion error: This feature is not implemented: Temporary tables not supported +CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; + +statement error DataFusion error: This feature is not implemented: Temporary views not supported +CREATE TEMPORARY VIEW y AS VALUES (1,2,3); From 181d38c2cb72cff38eb2ff28c79aa0649a05dfd6 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 13 Oct 2024 22:31:12 -0400 Subject: [PATCH 06/19] Adding test for verifying octet_length now works with string view (#12900) --- datafusion/sqllogictest/test_files/string/string_view.slt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 19bea3bf6bd0..b6dc696d271f 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -83,6 +83,14 @@ drop table test_source ## StringView Function test ######## +query I +select octet_length(column1_utf8view) from test; +---- +6 +9 +7 +NULL + query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View select bit_length(column1_utf8view) from test; From 849bbe75b446d38ac95cf09e457cdf4e92d4c9e2 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Mon, 14 Oct 2024 19:37:55 +0800 Subject: [PATCH 07/19] Remove Expr clones in `select_to_plan` (#12887) --- datafusion/expr/src/utils.rs | 9 ++++++--- datafusion/sql/src/select.rs | 13 ++++--------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 02b36d0feab9..06cf1ec693f0 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -600,7 +600,7 @@ pub fn group_window_expr_by_sort_keys( /// Collect all deeply nested `Expr::AggregateFunction`. /// They are returned in order of occurrence (depth /// first), with duplicates omitted. -pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { +pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::AggregateFunction { .. }) }) @@ -625,12 +625,15 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec { /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). 
-fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec +fn find_exprs_in_exprs<'a, F>( + exprs: impl IntoIterator, + test_fn: &F, +) -> Vec where F: Fn(&Expr) -> bool, { exprs - .iter() + .into_iter() .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) .fold(vec![], |mut acc, expr| { if !acc.contains(&expr) { diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 69c7745165f4..c665dec21df4 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -137,16 +137,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .transpose()?; - // The outer expressions we will search through for - // aggregates. Aggregates may be sourced from the SELECT... - let mut aggr_expr_haystack = select_exprs.clone(); - // ... or from the HAVING. - if let Some(having_expr) = &having_expr_opt { - aggr_expr_haystack.push(having_expr.clone()); - } - + // The outer expressions we will search through for aggregates. + // Aggregates may be sourced from the SELECT list or from the HAVING expression. + let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter()); // All of the aggregate expressions (deduplicated). - let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack); // All of the group by expressions let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by { From b932cdb3b7dcff97fe05c6ea3976b7d924b10200 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Mon, 14 Oct 2024 07:41:08 -0400 Subject: [PATCH 08/19] Minor: added to docs in expr folder (#12882) * Documentation fixes + add * fmt fix * quick fix --- .../expr/src/built_in_window_function.rs | 30 ++++++++-------- .../expr/src/conditional_expressions.rs | 4 +-- datafusion/expr/src/expr.rs | 18 +++++----- datafusion/expr/src/expr_schema.rs | 35 +++++++++++-------- datafusion/expr/src/function.rs | 4 +-- datafusion/expr/src/simplify.rs | 8 ++--- datafusion/expr/src/tree_node.rs | 19 +++++++++- datafusion/expr/src/udaf.rs | 8 ++--- datafusion/expr/src/udf_docs.rs | 18 +++++----- datafusion/expr/src/utils.rs | 20 +++++------ datafusion/expr/src/window_state.rs | 2 +- 11 files changed, 95 insertions(+), 71 deletions(-) diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 117ff08253b6..6a30080fb38b 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -37,28 +37,28 @@ impl fmt::Display for BuiltInWindowFunction { /// A [window function] built in to DataFusion /// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +/// [Window Function]: https://en.wikipedia.org/wiki/Window_function_(SQL) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) + /// Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible + /// Integer ranging from 1 to the argument value, dividing the partition as equally as possible Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). 
+ /// Returns value evaluated at the row that is offset rows before the current row within the partition; + /// If there is no such row, instead return default (which must be of the same type as value). /// Both offset and default are evaluated with respect to the current row. /// If omitted, offset defaults to 1 and default to null Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). + /// Returns value evaluated at the row that is offset rows after the current row within the partition; + /// If there is no such row, instead return default (which must be of the same type as value). /// Both offset and default are evaluated with respect to the current row. /// If omitted, offset defaults to 1 and default to null Lead, - /// returns value evaluated at the row that is the first row of the window frame + /// Returns value evaluated at the row that is the first row of the window frame FirstValue, - /// returns value evaluated at the row that is the last row of the window frame + /// Returns value evaluated at the row that is the last row of the window frame LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + /// Returns value evaluated at the row that is the nth row of the window frame (counting from 1); returns null if no such row NthValue, } @@ -99,10 +99,10 @@ impl BuiltInWindowFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function + // Verify that this is a valid set of data types for this function data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message + // Original errors are all related to wrong function signature + // Aggregate them for better error message .map_err(|_| { plan_datafusion_err!( "{}", @@ -125,9 +125,9 @@ impl BuiltInWindowFunction { } } - /// the signatures supported by the built-in window function `fun`. + /// The signatures supported by the built-in window function `fun`. pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. + // Note: The physical expression must accept the type returned by this function or the execution panics. 
match self { BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 7a2bf4b6c44a..23cc88f1c0ff 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -64,7 +64,7 @@ impl CaseBuilder { } fn build(&self) -> Result { - // collect all "then" expressions + // Collect all "then" expressions let mut then_expr = self.then_expr.clone(); if let Some(e) = &self.else_expr { then_expr.push(e.as_ref().to_owned()); @@ -79,7 +79,7 @@ impl CaseBuilder { .collect::>>()?; if then_types.contains(&DataType::Null) { - // cannot verify types until execution type + // Cannot verify types until execution type } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 723433f57341..3e692189e488 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -722,7 +722,7 @@ impl WindowFunctionDefinition { } } - /// the signatures supported by the function `fun`. + /// The signatures supported by the function `fun`. pub fn signature(&self) -> Signature { match self { WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), @@ -846,7 +846,7 @@ pub fn find_df_window_func(name: &str) -> Option { /// EXISTS expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Exists { - /// subquery that will produce a single column of data + /// Subquery that will produce a single column of data pub subquery: Subquery, /// Whether the expression is negated pub negated: bool, @@ -1329,7 +1329,7 @@ impl Expr { expr, Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) ) { - // subqueries could contain aliases so don't recurse into those + // Subqueries could contain aliases so don't recurse into those TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue @@ -1346,7 +1346,7 @@ impl Expr { } }, ) - // unreachable code: internal closure doesn't return err + // Unreachable code: internal closure doesn't return err .unwrap() } @@ -1416,7 +1416,7 @@ impl Expr { )) } - /// return `self NOT BETWEEN low AND high` + /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { Expr::Between(Between::new( Box::new(self), @@ -1817,7 +1817,7 @@ impl Expr { } } -// modifies expr if it is a placeholder with datatype of right +// Modifies expr if it is a placeholder with datatype of right fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { if data_type.is_none() { @@ -1890,7 +1890,7 @@ impl<'a> Display for SchemaDisplay<'a> { Ok(()) } - // expr is not shown since it is aliased + // Expr is not shown since it is aliased Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), Expr::Between(Between { expr, @@ -1945,7 +1945,7 @@ impl<'a> Display for SchemaDisplay<'a> { write!(f, "END") } - // cast expr is not shown to be consistant with Postgres and Spark + // Cast expr is not shown to be consistant with Postgres and Spark Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. 
}) => { write!(f, "{}", SchemaDisplay(expr)) } @@ -2415,7 +2415,7 @@ mod test { let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); - // note that CAST intentionally has a name that is different from its `Display` + // Note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. assert_eq!("Float32(1.23)", expr.schema_name().to_string()); Ok(()) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ad617c53d617..07a36672f272 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -35,27 +35,27 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; use std::sync::Arc; -/// trait to allow expr to typable with respect to a schema +/// Trait to allow expr to typable with respect to a schema pub trait ExprSchemable { - /// given a schema, return the type of the expr + /// Given a schema, return the type of the expr fn get_type(&self, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the nullability of the expr + /// Given a schema, return the nullability of the expr fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; - /// given a schema, return the expr's optional metadata + /// Given a schema, return the expr's optional metadata fn metadata(&self, schema: &dyn ExprSchema) -> Result>; - /// convert to a field with respect to a schema + /// Convert to a field with respect to a schema fn to_field( &self, input_schema: &dyn ExprSchema, ) -> Result<(Option, Arc)>; - /// cast to a type with respect to a schema + /// Cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the type and nullability of the expr + /// Given a schema, return the type and nullability of the expr fn data_type_and_nullable(&self, schema: &dyn ExprSchema) -> Result<(DataType, bool)>; } @@ -150,7 +150,7 @@ impl ExprSchemable for Expr { .map(|e| e.get_type(schema)) .collect::>>()?; - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) .map_err(|err| { plan_datafusion_err!( @@ -164,7 +164,7 @@ impl ExprSchemable for Expr { ) })?; - // perform additional function arguments validation (due to limited + // Perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) } @@ -223,7 +223,7 @@ impl ExprSchemable for Expr { } Expr::Wildcard { .. 
} => Ok(DataType::Null), Expr::GroupingSet(_) => { - // grouping sets do not really have a type and do not appear in projections + // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } } @@ -279,7 +279,7 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), Expr::Case(case) => { - // this expression is nullable if any of the input expressions are nullable + // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr .iter() @@ -336,7 +336,7 @@ impl ExprSchemable for Expr { } Expr::Wildcard { .. } => Ok(false), Expr::GroupingSet(_) => { - // grouping sets do not really have the concept of nullable and do not appear + // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) } @@ -439,7 +439,7 @@ impl ExprSchemable for Expr { return Ok(self); } - // TODO(kszucs): most of the operations do not validate the type correctness + // TODO(kszucs): Most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -526,7 +526,14 @@ impl Expr { } } -/// cast subquery in InSubquery/ScalarSubquery to a given type. +/// Cast subquery in InSubquery/ScalarSubquery to a given type. +/// +/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific +/// columns), it casts the first expression in the projection to the target type and creates a +/// new projection with the casted expression. +/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan +/// with the casted first column. +/// pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { if subquery.subquery.schema().field(0).data_type() == cast_to_type { return Ok(subquery); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 9814d16ddfa3..fca45dfe1498 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -67,7 +67,7 @@ pub type StateTypeFunction = /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +///Cclosure returns simplified [Expr] or an error. pub type AggregateFunctionSimplification = Box< dyn Fn( crate::expr::AggregateFunction, @@ -80,7 +80,7 @@ pub type AggregateFunctionSimplification = Box< /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked /// * 'info': [crate::simplify::SimplifyInfo] /// -/// closure returns simplified [Expr] or an error. +/// Closure returns simplified [Expr] or an error. pub type WindowFunctionSimplification = Box< dyn Fn( crate::expr::WindowFunction, diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index a55cb49b1f40..e636fabf10fb 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -29,10 +29,10 @@ use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; /// information in without having to create `DFSchema` objects. 
If you /// have a [`DFSchemaRef`] you can use [`SimplifyContext`] pub trait SimplifyInfo { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result; - /// returns true of this expr is nullable (could possibly be NULL) + /// Returns true of this expr is nullable (could possibly be NULL) fn nullable(&self, expr: &Expr) -> Result; /// Returns details needed for partial expression evaluation @@ -72,7 +72,7 @@ impl<'a> SimplifyContext<'a> { } impl<'a> SimplifyInfo for SimplifyContext<'a> { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result { if let Some(schema) = &self.schema { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -113,7 +113,7 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), - /// the function call could not be simplified, and the arguments + /// The function call could not be simplified, and the arguments /// are return unmodified. Original(Vec), } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index c7c498dd3f01..90afe5722abb 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Tree node implementation for logical expr +//! Tree node implementation for Logical Expressions use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, @@ -28,7 +28,16 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{map_until_stop_and_collect, Result}; +/// Implementation of the [`TreeNode`] trait +/// +/// This allows logical expressions (`Expr`) to be traversed and transformed +/// Facilitates tasks such as optimization and rewriting during query +/// planning. impl TreeNode for Expr { + /// Applies a function `f` to each child expression of `self`. + /// + /// The function `f` determines whether to continue traversing the tree or to stop. + /// This method collects all child expressions and applies `f` to each. fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, @@ -122,6 +131,10 @@ impl TreeNode for Expr { children.into_iter().apply_until_stop(f) } + /// Maps each child of `self` using the provided closure `f`. + /// + /// The closure `f` takes ownership of an expression and returns a `Transformed` result, + /// indicating whether the expression was transformed or left unchanged. fn map_children Result>>( self, mut f: F, @@ -346,6 +359,7 @@ impl TreeNode for Expr { } } +/// Transforms a boxed expression by applying the provided closure `f`. fn transform_box Result>>( be: Box, f: &mut F, @@ -353,6 +367,7 @@ fn transform_box Result>>( Ok(f(*be)?.update_data(Box::new)) } +/// Transforms an optional boxed expression by applying the provided closure `f`. fn transform_option_box Result>>( obe: Option>, f: &mut F, @@ -380,6 +395,7 @@ fn transform_vec Result>>( ve.into_iter().map_until_stop_and_collect(f) } +/// Transforms an optional vector of sort expressions by applying the provided closure `f`. pub fn transform_sort_option_vec Result>>( sorts_option: Option>, f: &mut F, @@ -389,6 +405,7 @@ pub fn transform_sort_option_vec Result>>( }) } +/// Transforms an vector of sort expressions by applying the provided closure `f`. 
pub fn transform_sort_vec Result>>( sorts: Vec, mut f: &mut F, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 6e48054bcf3d..dbbf88447ba3 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -140,7 +140,7 @@ impl AggregateUDF { )) } - /// creates an [`Expr`] that calls the aggregate function. + /// Creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. @@ -603,8 +603,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { } /// If this function is max, return true - /// if the function is min, return false - /// otherwise return None (the default) + /// If the function is min, return false + /// Otherwise return None (the default) /// /// /// Note: this is used to use special aggregate implementations in certain conditions @@ -647,7 +647,7 @@ impl PartialEq for dyn AggregateUDFImpl { } } -// manual implementation of `PartialOrd` +// Manual implementation of `PartialOrd` // There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl // https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 impl PartialOrd for dyn AggregateUDFImpl { diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs index e0ce7526036e..8e255566606c 100644 --- a/datafusion/expr/src/udf_docs.rs +++ b/datafusion/expr/src/udf_docs.rs @@ -33,21 +33,21 @@ use datafusion_common::Result; /// thus all text should be in English. #[derive(Debug, Clone)] pub struct Documentation { - /// the section in the documentation where the UDF will be documented + /// The section in the documentation where the UDF will be documented pub doc_section: DocSection, - /// the description for the UDF + /// The description for the UDF pub description: String, - /// a brief example of the syntax. For example "ascii(str)" + /// A brief example of the syntax. For example "ascii(str)" pub syntax_example: String, - /// a sql example for the UDF, usually in the form of a sql prompt + /// A sql example for the UDF, usually in the form of a sql prompt /// query and output. It is strongly recommended to provide an /// example for anything but the most basic UDF's pub sql_example: Option, - /// arguments for the UDF which will be displayed in array order. + /// Arguments for the UDF which will be displayed in array order. /// Left member of a pair is the argument name, right is a /// description for the argument pub arguments: Option>, - /// related functions if any. Values should match the related + /// Related functions if any. Values should match the related /// udf's name exactly. Related udf's must be of the same /// UDF type (scalar, aggregate or window) for proper linking to /// occur @@ -63,12 +63,12 @@ impl Documentation { #[derive(Debug, Clone, PartialEq)] pub struct DocSection { - /// true to include this doc section in the public + /// True to include this doc section in the public /// documentation, false otherwise pub include: bool, - /// a display label for the doc section. For example: "Math Expressions" + /// A display label for the doc section. 
For example: "Math Expressions" pub label: &'static str, - /// an optional description for the doc section + /// An optional description for the doc section pub description: Option<&'static str>, } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 06cf1ec693f0..9ee13f1e06d3 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -205,7 +205,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } - // only process mix grouping sets + // Only process mix grouping sets let partial_sets = group_expr .iter() .map(|expr| { @@ -234,7 +234,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .collect::>>()?; - // cross join + // Cross Join let grouping_sets = partial_sets .into_iter() .map(Ok) @@ -342,7 +342,7 @@ fn get_excluded_columns( // Excluded columns should be unique let n_elem = idents.len(); let unique_idents = idents.into_iter().collect::>(); - // if HashSet size, and vector length are different, this means that some of the excluded columns + // If HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. if n_elem != unique_idents.len() { return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); @@ -466,7 +466,7 @@ pub fn expand_qualified_wildcard( } /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") -/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column +/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column type WindowSortKey = Vec<(Sort, bool)>; /// Generate a sort key for a given window expr's partition_by and order_by expr @@ -573,7 +573,7 @@ pub fn compare_sort_expr( Ordering::Equal } -/// group a slice of window expression expr by their order by expressions +/// Group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( window_expr: Vec, ) -> Result)>> { @@ -656,7 +656,7 @@ where if !(exprs.contains(expr)) { exprs.push(expr.clone()) } - // stop recursing down this expr once we find a match + // Stop recursing down this expr once we find a match return Ok(TreeNodeRecursion::Jump); } @@ -675,7 +675,7 @@ where let mut err = Ok(()); expr.apply(|expr| { if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError + // Save the error for later (it may not be a DataFusionError) err = Err(e); Ok(TreeNodeRecursion::Stop) } else { @@ -694,7 +694,7 @@ pub fn exprlist_to_fields<'a>( exprs: impl IntoIterator, plan: &LogicalPlan, ) -> Result, Arc)>> { - // look for exact match in plan's output schema + // Look for exact match in plan's output schema let wildcard_schema = find_base_plan(plan).schema(); let input_schema = plan.schema(); let result = exprs @@ -953,8 +953,8 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes } -/// can this data type be used in hash join equal conditions?? -/// data types here come from function 'equal_rows', if more data types are supported +/// Can this data type be used in hash join equal conditions?? +/// Data types here come from function 'equal_rows', if more data types are supported /// in equal_rows(hash join), add those data types here to generate join logical plan. 
pub fn can_hash(data_type: &DataType) -> bool { match data_type { diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index e7f31bbfbf2b..f1d0ead23ab1 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -48,7 +48,7 @@ pub struct WindowAggState { /// Keeps track of how many rows should be generated to be in sync with input record_batch. // (For each row in the input record batch we need to generate a window result). pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition + /// Flag indicating whether we have received all data for this partition pub is_end: bool, } From 21cb3573c74166722d9c0b093328991545866b83 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 14 Oct 2024 07:41:34 -0400 Subject: [PATCH 09/19] Print undocumented functions to console while generating docs (#12874) --- .../core/src/bin/print_functions_docs.rs | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs index d9415028c124..d87c3cefe666 100644 --- a/datafusion/core/src/bin/print_functions_docs.rs +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -20,10 +20,16 @@ use datafusion_expr::{ aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, DocSection, Documentation, ScalarUDF, WindowUDF, }; +use hashbrown::HashSet; use itertools::Itertools; use std::env::args; use std::fmt::Write as _; +/// Print documentation for all functions of a given type to stdout +/// +/// Usage: `cargo run --bin print_functions_docs -- ` +/// +/// Called from `dev/update_function_docs.sh` fn main() { let args: Vec = args().collect(); @@ -83,9 +89,12 @@ fn print_docs( ) -> String { let mut docs = "".to_string(); + // Ensure that all providers have documentation + let mut providers_with_no_docs = HashSet::new(); + // doc sections only includes sections that have 'include' == true for doc_section in doc_sections { - // make sure there is a function that is in this doc section + // make sure there is at least one function that is in this doc section if !&providers.iter().any(|f| { if let Some(documentation) = f.get_documentation() { documentation.doc_section == doc_section @@ -96,12 +105,14 @@ fn print_docs( continue; } + // filter out functions that are not in this doc section let providers: Vec<&Box> = providers .iter() .filter(|&f| { if let Some(documentation) = f.get_documentation() { documentation.doc_section == doc_section } else { + providers_with_no_docs.insert(f.get_name()); false } }) @@ -202,9 +213,19 @@ fn print_docs( } } + // If there are any functions that do not have documentation, print them out + // eventually make this an error: https://github.com/apache/datafusion/issues/12872 + if !providers_with_no_docs.is_empty() { + eprintln!("INFO: The following functions do not have documentation:"); + for f in providers_with_no_docs { + eprintln!(" - {f}"); + } + } + docs } +/// Trait for accessing name / aliases / documentation for differnet functions trait DocProvider { fn get_name(&self) -> String; fn get_aliases(&self) -> Vec; From 16589b56a161a8ff13d16b3e55d7abbfb9d94f4b Mon Sep 17 00:00:00 2001 From: HuSen Date: Mon, 14 Oct 2024 19:42:31 +0800 Subject: [PATCH 10/19] Fix: handle NULL offset of NTH_VALUE window function (#12851) --- datafusion/physical-expr/src/window/nth_value.rs | 2 +- datafusion/sqllogictest/test_files/window.slt | 9 +++++++++ 2 
files changed, 10 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index d94983c5adf7..6ec3a23fc586 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -185,7 +185,7 @@ impl PartitionEvaluator for NthValueEvaluator { // Negative index represents reverse direction. (n_range >= reverse_index, true) } - Ordering::Equal => (true, false), + Ordering::Equal => (false, false), } } }; diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 40309a1f2de9..79cb91e183db 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -4929,6 +4929,15 @@ SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t; 2 NULL 2 NULL +query I +SELECT NTH_VALUE(tt0.v1, NULL) OVER (PARTITION BY tt0.v2 ORDER BY tt0.v1) FROM t AS tt0; +---- +NULL +NULL +NULL +NULL +NULL + statement ok DROP TABLE t; From 746380b303418b25fcc424cbfc1057d6b2c8f0dc Mon Sep 17 00:00:00 2001 From: Simon Vandel Sillesen Date: Mon, 14 Oct 2024 13:44:16 +0200 Subject: [PATCH 11/19] Optimize `signum` function (3-25x faster) (#12890) * add bench * optimize signum --- datafusion/functions/Cargo.toml | 5 ++ datafusion/functions/benches/signum.rs | 46 ++++++++++++++++++ datafusion/functions/src/math/signum.rs | 64 ++++++++++++------------- 3 files changed, 81 insertions(+), 34 deletions(-) create mode 100644 datafusion/functions/benches/signum.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 2ffe93a0e567..e08dfb2de07e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -137,6 +137,11 @@ harness = false name = "to_char" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "signum" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs new file mode 100644 index 000000000000..9f8d8258c823 --- /dev/null +++ b/datafusion/functions/benches/signum.rs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::signum; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let signum = signum(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("signum f32 array: {}", size), |b| { + b.iter(|| black_box(signum.invoke(&f32_args).unwrap())) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("signum f64 array: {}", size), |b| { + b.iter(|| black_box(signum.invoke(&f64_args).unwrap())) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index d2a806a46e13..15b73f930343 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -18,11 +18,11 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -86,37 +86,33 @@ impl ScalarUDFImpl for SignumFunc { /// signum SQL function pub fn signum(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float64Array, - Float64Array, - { - |x: f64| { - if x == 0_f64 { - 0_f64 - } else { - x.signum() - } - } - } - )) as ArrayRef), - - Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "signum", - Float32Array, - Float32Array, - { - |x: f32| { - if x == 0_f32 { - 0_f32 - } else { - x.signum() - } - } - } - )) as ArrayRef), + Float64 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>( + |x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), + + Float32 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>( + |x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function signum"), } From 4ceb95067c5c95544b103f54b99be27ca63dee33 Mon Sep 17 00:00:00 2001 From: Michael J Ward Date: Mon, 14 Oct 2024 06:45:23 -0500 Subject: [PATCH 12/19] re-export PartitionEvaluatorArgs from datafusion_expr::function (#12878) This is needed to implement the WindowUDF trait. 
--- datafusion/expr/src/function.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index fca45dfe1498..199a91bf5ace 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -28,6 +28,7 @@ pub use datafusion_functions_aggregate_common::accumulator::{ }; pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; +pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; #[derive(Debug, Clone, Copy)] pub enum Hint { From 6e4bf05b3bd7ee5b81ab8fb24eb98f236856ab3c Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 14 Oct 2024 19:46:13 +0800 Subject: [PATCH 13/19] Unparse Sort with pushdown limit to SQL string (#12873) * unparse Sort with push down limit * cargo fmt * set query limit directly --- datafusion/sql/src/unparser/plan.rs | 6 ++++++ datafusion/sql/tests/cases/plan_to_sql.rs | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c4fcbb2d6458..d150f0e532c6 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -357,6 +357,12 @@ impl Unparser<'_> { return self.derive(plan, relation); } if let Some(query_ref) = query { + if let Some(fetch) = sort.fetch { + query_ref.limit(Some(ast::Expr::Value(ast::Value::Number( + fetch.to_string(), + false, + )))); + } query_ref.order_by(self.sorts_to_sql(sort.expr.clone())?); } else { return internal_err!( diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 903d4e28520b..aff9f99c8cd3 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -841,6 +841,26 @@ fn test_table_scan_pushdown() -> Result<()> { Ok(()) } +#[test] +fn test_sort_with_push_down_fetch() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id"), col("age")])? + .sort_with_limit(vec![col("age").sort(true, true)], Some(10))? 
+ .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!( + format!("{}", sql), + "SELECT t1.id, t1.age FROM t1 ORDER BY t1.age ASC NULLS FIRST LIMIT 10" + ); + Ok(()) +} + #[test] fn test_interval_lhs_eq() { sql_round_trip( From 6c0670d1c42bf13b74c5edf6880f044f8ca3b818 Mon Sep 17 00:00:00 2001 From: Yongting You <2010youy01@gmail.com> Date: Mon, 14 Oct 2024 19:49:40 +0800 Subject: [PATCH 14/19] Add spilling related metrics for aggregation (#12888) * External aggregation metrics * clippy --- .../physical-plan/src/aggregates/mod.rs | 12 ++++++ .../physical-plan/src/aggregates/row_hash.rs | 41 ++++++++++++++++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d6f16fb0fdd3..296c5811e577 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1686,12 +1686,24 @@ mod tests { let metrics = merged_aggregate.metrics().unwrap(); let output_rows = metrics.output_rows().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + if spill { // When spilling, the output rows metrics become partial output size + final output size // This is because final aggregation starts while partial aggregation is still emitting assert_eq!(8, output_rows); + + assert!(spill_count > 0); + assert!(spilled_bytes > 0); + assert!(spilled_rows > 0); } else { assert_eq!(3, output_rows); + + assert_eq!(0, spill_count); + assert_eq!(0, spilled_bytes); + assert_eq!(0, spilled_rows); } Ok(()) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 5121e6cc3b35..624844b6b985 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -102,6 +102,19 @@ struct SpillState { /// true when streaming merge is in progress is_stream_merging: bool, + + // ======================================================================== + // METRICS: + // ======================================================================== + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, + /// count of spill files during the execution of the operator + spill_count: metrics::Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: metrics::Count, + /// total spilled rows during the execution of the operator + spilled_rows: metrics::Count, } /// Tracks if the aggregate should skip partial aggregations @@ -138,6 +151,9 @@ struct SkipAggregationProbe { /// make any effect (set either while probing or on probing completion) is_locked: bool, + // ======================================================================== + // METRICS: + // ======================================================================== /// Number of rows where state was output without aggregation. 
/// /// * If 0, all input rows were aggregated (should_skip was always false) @@ -510,6 +526,11 @@ impl GroupedHashAggregateStream { is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + peak_mem_used: MetricBuilder::new(&agg.metrics) + .gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition), }; // Skip aggregation is supported if: @@ -865,11 +886,19 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - self.reservation.try_resize( + let reservation_result = self.reservation.try_resize( acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(), - ) + ); + + if reservation_result.is_ok() { + self.spill_state + .peak_mem_used + .set_max(self.reservation.size()); + } + + reservation_result } /// Create an output RecordBatch with the group keys and @@ -946,6 +975,14 @@ impl GroupedHashAggregateStream { self.batch_size, )?; self.spill_state.spills.push(spillfile); + + // Update metrics + self.spill_state.spill_count.add(1); + self.spill_state + .spilled_bytes + .add(sorted.get_array_memory_size()); + self.spill_state.spilled_rows.add(sorted.num_rows()); + Ok(()) } From 5391c98f7a3fda1f8eef994591286b1596033bc5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 14 Oct 2024 08:52:10 -0400 Subject: [PATCH 15/19] Move equivalence fuzz testing to fuzz test binary (#12767) * Move equivalence fuzz testing to fuzz test binary * fix license * fixup --- .../core/tests/fuzz_cases/equivalence/mod.rs | 23 + .../tests/fuzz_cases/equivalence/ordering.rs | 160 ++++++ .../fuzz_cases/equivalence/projection.rs | 200 ++++++++ .../fuzz_cases/equivalence/properties.rs | 105 ++++ .../tests/fuzz_cases/equivalence/utils.rs | 463 ++++++++++++++++++ datafusion/core/tests/fuzz_cases/mod.rs | 2 + .../physical-expr/src/equivalence/mod.rs | 89 ---- .../physical-expr/src/equivalence/ordering.rs | 138 +----- .../src/equivalence/projection.rs | 177 +------ .../src/equivalence/properties.rs | 84 +--- 10 files changed, 958 insertions(+), 483 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/equivalence/mod.rs create mode 100644 datafusion/core/tests/fuzz_cases/equivalence/ordering.rs create mode 100644 datafusion/core/tests/fuzz_cases/equivalence/projection.rs create mode 100644 datafusion/core/tests/fuzz_cases/equivalence/properties.rs create mode 100644 datafusion/core/tests/fuzz_cases/equivalence/utils.rs diff --git a/datafusion/core/tests/fuzz_cases/equivalence/mod.rs b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs new file mode 100644 index 000000000000..2f8a38200bf1 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `EquivalenceProperties` fuzz testing + +mod ordering; +mod projection; +mod properties; +mod utils; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs new file mode 100644 index 000000000000..b1ee24a7a373 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + 
); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(&requirement), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs new file mode 100644 index 000000000000..c0c8517a612b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::fuzz_cases::equivalence::utils::{ + apply_projection, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). 
+ assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) +} + +#[test] +fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| Arc::clone(target)) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::>(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(&requirement), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs new file mode 100644 index 000000000000..e704fcacc328 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + TestScalarUDF, +}; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options: sort_expr.options, + }) + .collect::>(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). 
+ assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs new file mode 100644 index 000000000000..e51dabd6437f --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -0,0 +1,463 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// use datafusion_physical_expr::expressions::{col, Column}; +use datafusion::physical_plan::expressions::col; +use datafusion::physical_plan::expressions::Column; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; +use std::any::Any; +use std::sync::Arc; + +use arrow::compute::{lexsort_to_indices, SortColumn}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; + +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use itertools::izip; +use rand::prelude::*; + +pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, +) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) +} + +// Generate a schema which consists of 6 columns (a, b, c, d, e, f) +fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) +} + +/// Construct a schema with random ordering +/// among column a, b, c, d +/// where +/// Column [a=f] (e.g they are aliases). +/// Column e is constant. 
+pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f)?; + // Column e has constant value. + eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) +} + +// Apply projection to the input_data, return projected equivalence properties and record batch +pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, +) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(Arc::clone(&output_schema)) + } else { + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? 
+ }; + + let projected_eq = input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) +} + +#[test] +fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) +} + +/// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. +/// +/// The function works by adding a unique column of ascending integers to the original table. This column ensures +/// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can +/// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce +/// deterministic sorting results. +/// +/// If the table remains the same after sorting with the added unique column, it indicates that the table was +/// already sorted according to `required_ordering` to begin with. 
+pub fn is_table_same_after_sort( + mut required_ordering: Vec, + batch: RecordBatch, +) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(Arc::clone(&unique_col)); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) +} + +// If we already generated a random result for one of the +// expressions in the equivalence classes. For other expressions in the same +// equivalence class use same result. This util gets already calculated result, when available. +fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, +) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(Arc::clone(res)); + } + } + None +} + +// Generate a table that satisfies the given equivalence properties; i.e. +// equivalences, ordering equivalences, and constants. 
+pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.expr().as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(Arc::clone(&representative_array)); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) 
+} + +#[derive(Debug, Clone)] +pub struct TestScalarUDF { + pub(crate) signature: Signature, +} + +impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 5bc36b963c44..49db0d31a8e9 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -22,6 +22,8 @@ mod merge_fuzz; mod sort_fuzz; mod aggregation_fuzzer; +mod equivalence; + mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 38647f7ca1d4..7726458a46ac 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -84,7 +84,6 @@ mod tests { use itertools::izip; use rand::rngs::StdRng; - use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; pub fn output_schema( @@ -175,67 +174,6 @@ mod tests { Ok((test_schema, eq_properties)) } - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. 
- pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f)?; - // Column e has constant value. - eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - // Convert each tuple to PhysicalSortRequirement pub fn convert_to_sort_reqs( in_data: &[(&Arc, Option)], @@ -294,33 +232,6 @@ mod tests { .collect() } - // Apply projection to the input_data, return projected equivalence properties and record batch - pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(Arc::clone(&output_schema)) - } else { - RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? 
- }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - #[test] fn add_equal_conditions_test() -> Result<()> { let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index bb3e9218bc41..a3cf8c965b69 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -254,9 +254,8 @@ mod tests { use std::sync::Arc; use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_random_schema, - create_test_params, create_test_schema, generate_table_for_eq_properties, - is_table_same_after_sort, + convert_to_orderings, convert_to_sort_exprs, create_test_params, + create_test_schema, generate_table_for_eq_properties, is_table_same_after_sort, }; use crate::equivalence::{ EquivalenceClass, EquivalenceGroup, EquivalenceProperties, @@ -271,8 +270,6 @@ mod tests { use datafusion_common::{DFSchema, Result}; use datafusion_expr::{Operator, ScalarUDF}; - use itertools::Itertools; - #[test] fn test_ordering_satisfy() -> Result<()> { let input_schema = Arc::new(Schema::new(vec![ @@ -771,137 +768,6 @@ mod tests { Ok(()) } - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index ebf26d3262aa..25a05a2a5918 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -139,23 +139,18 @@ fn project_index_to_exprs( mod tests { use super::*; use crate::equivalence::tests::{ - apply_projection, convert_to_orderings, convert_to_orderings_owned, - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - output_schema, + convert_to_orderings, convert_to_orderings_owned, output_schema, }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; - use crate::PhysicalSortExpr; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::DFSchema; use datafusion_expr::{Operator, ScalarUDF}; - use itertools::Itertools; - #[test] fn project_orderings() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -987,174 +982,4 @@ mod tests { Ok(()) } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. - for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). 
- assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (Arc::clone(expr), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| Arc::clone(target)) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: Arc::clone(expr), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. 
- assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 005e5776d3ae..a0cc29685f77 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -2101,16 +2101,13 @@ mod tests { use crate::equivalence::add_offset_to_expr; use crate::equivalence::tests::{ convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, - create_random_schema, create_test_params, create_test_schema, - generate_table_for_eq_properties, is_table_same_after_sort, output_schema, + create_test_params, create_test_schema, output_schema, }; use crate::expressions::{col, BinaryExpr, Column}; - use crate::utils::tests::TestScalarUDF; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, TimeUnit}; - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; + use datafusion_expr::Operator; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -2621,83 +2618,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_longest_permutation_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: Arc::clone(&exprs[idx]), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). 
- assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } #[test] fn test_find_longest_permutation() -> Result<()> { // Schema satisfies following orderings: From f2564b7ae5e63e13f835f3ce3719f503bee9be08 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 15 Oct 2024 01:20:12 +0800 Subject: [PATCH 16/19] Remove unused `math_expressions.rs` (#12917) --- datafusion/physical-expr/src/lib.rs | 1 - .../physical-expr/src/math_expressions.rs | 126 ------------------ datafusion/sqllogictest/test_files/math.slt | 7 +- 3 files changed, 6 insertions(+), 128 deletions(-) delete mode 100644 datafusion/physical-expr/src/math_expressions.rs diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 46185712413e..e7c2b4119c5a 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -27,7 +27,6 @@ pub mod binary_map { pub mod equivalence; pub mod expressions; pub mod intervals; -pub mod math_expressions; mod partitioning; mod physical_expr; pub mod planner; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs deleted file mode 100644 index 503565b1e261..000000000000 --- a/datafusion/physical-expr/src/math_expressions.rs +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math expressions - -use std::any::type_name; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::array::{BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow_array::Array; - -use datafusion_common::exec_err; -use datafusion_common::{DataFusionError, Result}; - -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} from {} to {}", - $NAME, - $ARG.data_type(), - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! 
make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - -/// Isnan SQL function -pub fn isnan(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { f64::is_nan } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { f32::is_nan } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function isnan"), - } -} - -#[cfg(test)] -mod tests { - - use datafusion_common::cast::as_boolean_array; - - use super::*; - - #[test] - fn test_isnan_f64() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 1.0, - f64::NAN, - 3.0, - -f64::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_isnan_f32() { - let args: Vec = vec![Arc::new(Float32Array::from(vec![ - 1.0, - f32::NAN, - 3.0, - f32::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index eece56942317..1bc972a3e37d 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -102,7 +102,12 @@ SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) # isnan query BBBB -SELECT isnan(1.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +SELECT isnan(1.0::DOUBLE), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +---- +false true true NULL + +query BBBB +SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL::FLOAT) ---- false true true NULL From fce331a4f7ca908ee7b4500d104fae9ec657f471 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Mon, 14 Oct 2024 13:20:21 -0400 Subject: [PATCH 17/19] Migrate Regex Functions from static docs (#12886) * regex migrate * small fixes * update docs --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/regex/regexpmatch.rs | 49 +++++++++- .../functions/src/regex/regexpreplace.rs | 52 +++++++++- .../source/user-guide/sql/scalar_functions.md | 97 ------------------- .../user-guide/sql/scalar_functions_new.md | 82 ++++++++++++++++ 4 files changed, 178 insertions(+), 102 deletions(-) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index bfec97f92c36..443e50533268 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -26,10 +26,11 @@ use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ 
cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::{ColumnarValue, TypeSignature}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct RegexpMatchFunc { @@ -106,7 +107,51 @@ impl ScalarUDFImpl for RegexpMatchFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_match_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_match_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") + .with_syntax_example("regexp_match(str, regexp[, flags])") + .with_sql_example(r#"```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", "String") + .with_argument("regexp","Regular expression to match against. + Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . 
to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) } + fn regexp_match_func(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => regexp_match::(args), diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index bce8752af28b..279e5c6ba9dd 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -32,14 +32,15 @@ use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; use datafusion_expr::function::Hint; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use regex::Regex; use std::any::Any; use std::collections::HashMap; -use std::sync::Arc; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; + #[derive(Debug)] pub struct RegexpReplaceFunc { signature: Signature, @@ -123,6 +124,51 @@ impl ScalarUDFImpl for RegexpReplaceFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_replace_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_replace_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax).") + .with_syntax_example("regexp_replace(str, regexp, replacement[, flags])") + .with_sql_example(r#"```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", "String") + .with_argument("regexp","Regular expression to match against. + Can be a constant, column, or function.") + .with_standard_argument("replacement", "Replacement string") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . 
to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() +}) } fn regexp_replace_func(args: &[ColumnarValue]) -> Result { diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 95762d958521..56145ec803e0 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -563,103 +563,6 @@ See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/ See the new documentation [`here`](https://datafusion.apache.org/user-guide/sql/scalar_functions_new.html) -## Regular Expression Functions - -Apache DataFusion uses a [PCRE-like] regular expression [syntax] -(minus support for several features including look-around and backreferences). -The following regular expression functions are supported: - -- [regexp_match](#regexp_match) -- [regexp_replace](#regexp_replace) - -[pcre-like]: https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions -[syntax]: https://docs.rs/regex/latest/regex/#syntax - -### `regexp_match` - -Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. - -``` -regexp_match(str, regexp[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+---------------------------------------------------------+ -| regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+---------------------------------------------------------+ -| [Köln] | -+---------------------------------------------------------+ -SELECT regexp_match('aBc', '(b|d)', 'i'); -+---------------------------------------------------+ -| regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+---------------------------------------------------+ -| [B] | -+---------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - -### `regexp_replace` - -Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). - -``` -regexp_replace(str, regexp, replacement[, flags]) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to match against. - Can be a constant, column, or function. -- **replacement**: Replacement string expression. - Can be a constant, column, or function, and any combination of string operators. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. 
The following flags are supported: - - **g**: (global) Search globally and don't return after the first match - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? - -#### Example - -```sql -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); -+------------------------------------------------------------------------+ -| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | -+------------------------------------------------------------------------+ -| fooXarYXazY | -+------------------------------------------------------------------------+ -SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); -+-------------------------------------------------------------------+ -| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | -+-------------------------------------------------------------------+ -| aAbBac | -+-------------------------------------------------------------------+ -``` - -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) - ### `position` Returns the position of `substr` in `origstr` (counting from 1). If `substr` does diff --git a/docs/source/user-guide/sql/scalar_functions_new.md b/docs/source/user-guide/sql/scalar_functions_new.md index 96fbcaa1104b..7d0261da0ad0 100644 --- a/docs/source/user-guide/sql/scalar_functions_new.md +++ b/docs/source/user-guide/sql/scalar_functions_new.md @@ -1191,6 +1191,8 @@ regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) The following regular expression functions are supported: - [regexp_like](#regexp_like) +- [regexp_match](#regexp_match) +- [regexp_replace](#regexp_replace) ### `regexp_like` @@ -1230,6 +1232,86 @@ SELECT regexp_like('aBc', '(b|d)', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +### `regexp_match` + +Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. + +``` +regexp_match(str, regexp[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? 
+ +#### Example + +```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + +### `regexp_replace` + +Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax). + +``` +regexp_replace(str, regexp, replacement[, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to match against. + Can be a constant, column, or function. +- **replacement**: Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*? 
+ +#### Example + +```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) + ## Time and Date Functions - [to_date](#to_date) From 377a4c553b04fbcf7609384a501af9a30fe02dbe Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 14 Oct 2024 22:24:48 -0400 Subject: [PATCH 18/19] Improve AggregationFuzzer error reporting (#12832) * Improve AggregationFuzzer error reporting * simplify * Update datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs * fmt --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 2 +- .../fuzz_cases/aggregation_fuzzer/fuzzer.rs | 90 ++++++++++++------- 2 files changed, 59 insertions(+), 33 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 64a7514ebd5e..34061a64d783 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -83,7 +83,7 @@ async fn test_basic_prim_aggr_no_group() { .table_name("fuzz_table") .build(); - fuzzer.run().await; + fuzzer.run().await } /// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single int64` diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs index abb34048284d..6daebc894272 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use arrow::util::pretty::pretty_format_batches; use arrow_array::RecordBatch; +use datafusion_common::{DataFusionError, Result}; use rand::{thread_rng, Rng}; use tokio::task::JoinSet; @@ -132,7 +133,20 @@ struct QueryGroup { } impl AggregationFuzzer { + /// Run the fuzzer, printing an error and panicking if any of the tasks fail pub async fn run(&self) { + let res = self.run_inner().await; + + if let Err(e) = res { + // Print the error via `Display` so that it displays nicely (the default `unwrap()` + // prints using `Debug` which escapes newlines, and makes multi-line messages + // hard to read + println!("{e}"); + panic!("Error!"); + } + } + + async fn run_inner(&self) -> Result<()> { let mut join_set = JoinSet::new(); let mut rng = thread_rng(); @@ -157,16 +171,20 @@ impl AggregationFuzzer { let tasks = self.generate_fuzz_tasks(query_groups).await; for task in tasks { - join_set.spawn(async move { - task.run().await; - }); + join_set.spawn(async move { task.run().await }); } } while let Some(join_handle) = join_set.join_next().await { // propagate errors - join_handle.unwrap(); + join_handle.map_err(|e| { + DataFusionError::Internal(format!( + "AggregationFuzzer task error: {:?}", + e + )) + })??; } + Ok(()) } 
async fn generate_fuzz_tasks( @@ -237,45 +255,53 @@ struct AggregationFuzzTestTask { } impl AggregationFuzzTestTask { - async fn run(&self) { + async fn run(&self) -> Result<()> { let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx) .await - .expect("should success to run sql"); - self.check_result(&task_result, &self.expected_result); + .map_err(|e| e.context(self.context_error_report()))?; + self.check_result(&task_result, &self.expected_result) } - // TODO: maybe we should persist the `expected_result` and `task_result`, - // because the readability is not so good if we just print it. - fn check_result(&self, task_result: &[RecordBatch], expected_result: &[RecordBatch]) { - let result = check_equality_of_batches(task_result, expected_result); - if let Err(e) = result { + fn check_result( + &self, + task_result: &[RecordBatch], + expected_result: &[RecordBatch], + ) -> Result<()> { + check_equality_of_batches(task_result, expected_result).map_err(|e| { // If we found inconsistent result, we print the test details for reproducing at first - println!( - "##### AggregationFuzzer error report ##### - ### Sql:\n{}\n\ - ### Schema:\n{}\n\ - ### Session context params:\n{:?}\n\ - ### Inconsistent row:\n\ - - row_idx:{}\n\ - - task_row:{}\n\ - - expected_row:{}\n\ - ### Task total result:\n{}\n\ - ### Expected total result:\n{}\n\ - ### Input:\n{}\n\ - ", - self.sql, - self.dataset_ref.batches[0].schema_ref(), - self.ctx_with_params.params, + let message = format!( + "{}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ", + self.context_error_report(), e.row_idx, e.lhs_row, e.rhs_row, pretty_format_batches(task_result).unwrap(), pretty_format_batches(expected_result).unwrap(), - pretty_format_batches(&self.dataset_ref.batches).unwrap(), ); + DataFusionError::Internal(message) + }) + } - // Then we just panic - panic!(); - } + /// Returns a formatted error message + fn context_error_report(&self) -> String { + format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ) } } From d9450da6991a2977c617dfa789f208c53ba11421 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 15 Oct 2024 08:40:34 +0200 Subject: [PATCH 19/19] Import Arc consistently (#12899) Just a code cleanup In many places the qualified name is redundant since `Arc` is already imported. In some `use` was added. In all cases `Arc` is unambiguous. 
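For illustration, a minimal sketch of the convention this cleanup standardizes on: once `use std::sync::Arc;` is in scope, the short form `Arc<T>` names the same type as the fully qualified `std::sync::Arc<T>`, so the qualified path adds nothing. The function and values below are a hypothetical example, not code taken from the patch.

```rust
use std::sync::Arc;

// Hypothetical example: with `Arc` imported above, `Arc<str>` and
// `std::sync::Arc<str>` refer to the same type, so the short form is used.
fn share(value: &str) -> Arc<str> {
    Arc::from(value)
}

fn main() {
    let shared: Arc<str> = share("hello");
    // Cloning an `Arc` only increments the reference count; the data is shared.
    let also_shared = Arc::clone(&shared);
    assert_eq!(shared, also_shared);
    println!("{also_shared}");
}
```

The same reasoning applies to the hunks below, where signatures such as `Result<std::sync::Arc<dyn FileFormat>>` shorten to `Result<Arc<dyn FileFormat>>` once the import is present.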
--- .../examples/custom_file_format.rs | 4 +-- datafusion/functions/src/regex/mod.rs | 4 ++- .../src/replace_distinct_aggregate.rs | 4 ++- datafusion/physical-plan/src/metrics/value.rs | 4 +-- .../proto/src/logical_plan/file_formats.rs | 30 +++++++------------ .../tests/cases/roundtrip_logical_plan.rs | 3 +- 6 files changed, 21 insertions(+), 28 deletions(-) diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs index 1d9b587f15b9..b85127d42f71 100644 --- a/datafusion-examples/examples/custom_file_format.rs +++ b/datafusion-examples/examples/custom_file_format.rs @@ -154,7 +154,7 @@ impl FileFormatFactory for TSVFileFactory { &self, state: &SessionState, format_options: &std::collections::HashMap, - ) -> Result> { + ) -> Result> { let mut new_options = format_options.clone(); new_options.insert("format.delimiter".to_string(), "\t".to_string()); @@ -164,7 +164,7 @@ impl FileFormatFactory for TSVFileFactory { Ok(tsv_file_format) } - fn default(&self) -> std::sync::Arc { + fn default(&self) -> Arc { todo!() } diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4afbe6cbbb89..cde777311aa1 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -17,6 +17,8 @@ //! "regex" DataFusion functions +use std::sync::Arc; + pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; @@ -67,6 +69,6 @@ pub mod expr_fn { } /// Returns all DataFusion functions defined in this package -pub fn functions() -> Vec> { +pub fn functions() -> Vec> { vec![regexp_match(), regexp_like(), regexp_replace()] } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index c026130c426f..f3e1673e7211 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -16,8 +16,10 @@ // under the License. //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` + use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; +use std::sync::Arc; use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; @@ -110,7 +112,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { let expr_cnt = on_expr.len(); // Construct the aggregation expression to be used to fetch the selected expressions. 
- let first_value_udaf: std::sync::Arc = + let first_value_udaf: Arc = config.function_registry().unwrap().udaf("first_value")?; let aggr_expr = select_expr.into_iter().map(|e| { if let Some(order_by) = &sort_expr { diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index 22db8f1e4e88..5a335d9f99cd 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -37,7 +37,7 @@ use parking_lot::Mutex; #[derive(Debug, Clone)] pub struct Count { /// value of the metric counter - value: std::sync::Arc, + value: Arc, } impl PartialEq for Count { @@ -86,7 +86,7 @@ impl Count { #[derive(Debug, Clone)] pub struct Gauge { /// value of the metric gauge - value: std::sync::Arc, + value: Arc, } impl PartialEq for Gauge { diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs index 0f9f9d335afe..98034e3082af 100644 --- a/datafusion/proto/src/logical_plan/file_formats.rs +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -180,16 +180,14 @@ impl LogicalExtensionCodec for CsvLogicalExtensionCodec { _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -292,16 +290,14 @@ impl LogicalExtensionCodec for JsonLogicalExtensionCodec { _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -591,16 +587,14 @@ impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -681,16 +675,14 @@ impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, _ctx: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") @@ -741,16 +733,14 @@ impl LogicalExtensionCodec for AvroLogicalExtensionCodec { _table_ref: &TableReference, _schema: arrow::datatypes::SchemaRef, _cts: &datafusion::prelude::SessionContext, - ) -> datafusion_common::Result< - std::sync::Arc, - > { + ) -> datafusion_common::Result> { not_impl_err!("Method not implemented") } fn 
try_encode_table_provider( &self, _table_ref: &TableReference, - _node: std::sync::Arc, + _node: Arc, _buf: &mut Vec, ) -> datafusion_common::Result<()> { not_impl_err!("Method not implemented") diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 80caaafad6f6..75881a421d17 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -69,8 +69,7 @@ impl SerializerRegistry for MockSerializerRegistry { &self, name: &str, bytes: &[u8], - ) -> Result> - { + ) -> Result> { if name == "MockUserDefinedLogicalPlan" { MockUserDefinedLogicalPlan::deserialize(bytes) } else {