From 33aea7e6407406a9ec2cbc27323e6f499ebbcd76 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Thu, 16 May 2024 03:23:21 +1000 Subject: [PATCH 01/11] Remove `file_type()` from `FileFormat` (#10499) --- datafusion/core/src/datasource/file_format/arrow.rs | 6 +----- datafusion/core/src/datasource/file_format/avro.rs | 5 ----- datafusion/core/src/datasource/file_format/csv.rs | 9 +++------ datafusion/core/src/datasource/file_format/json.rs | 6 +----- datafusion/core/src/datasource/file_format/mod.rs | 5 +---- datafusion/core/src/datasource/file_format/parquet.rs | 6 +----- 6 files changed, 7 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 9d58465191e1..8c6790541597 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -40,7 +40,7 @@ use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_common::{not_impl_err, DataFusionError, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; @@ -136,10 +136,6 @@ impl FileFormat for ArrowFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::ARROW - } } /// Implements [`DataSink`] for writing to arrow_ipc files diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 132dae14c684..7b2c26a2c4f9 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::FileType; use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; @@ -89,10 +88,6 @@ impl FileFormat for AvroFormat { let exec = AvroExec::new(conf); Ok(Arc::new(exec)) } - - fn file_type(&self) -> FileType { - FileType::AVRO - } } #[cfg(test)] diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 17bc7aafce85..ae5ac52025cf 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -42,7 +42,7 @@ use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion_common::config::CsvOptions; use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_common::{exec_err, not_impl_err, DataFusionError}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -280,10 +280,6 @@ impl FileFormat for CsvFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::CSV - } } impl CsvFormat { @@ -549,8 +545,9 @@ mod tests { use arrow::compute::concat_batches; use datafusion_common::cast::as_string_array; + use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use 
datafusion_common::{internal_err, GetExt}; + use datafusion_common::{FileType, GetExt}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::{col, lit}; diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 9f526e1c87b4..6e6c79848594 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -43,7 +43,7 @@ use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; use datafusion_common::config::JsonOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::{not_impl_err, FileType}; +use datafusion_common::not_impl_err; use datafusion_execution::TaskContext; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use datafusion_physical_plan::metrics::MetricsSet; @@ -184,10 +184,6 @@ impl FileFormat for JsonFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::JSON - } } impl Default for JsonSerializer { diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index fdb89a264951..243a91b7437b 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -41,7 +41,7 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use datafusion_common::{not_impl_err, FileType}; +use datafusion_common::not_impl_err; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; use async_trait::async_trait; @@ -104,9 +104,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug { ) -> Result> { not_impl_err!("Writer not implemented for this format") } - - /// Returns the FileType corresponding to this FileFormat - fn file_type(&self) -> FileType; } #[cfg(test)] diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index fa379eb5b445..8182ced6f228 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -47,7 +47,7 @@ use datafusion_common::config::TableParquetOptions; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, FileType, + exec_err, internal_datafusion_err, not_impl_err, DataFusionError, }; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; @@ -286,10 +286,6 @@ impl FileFormat for ParquetFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::PARQUET - } } fn summarize_min_max( From 405a5f60406f8e2757d2d6fb8e3cc46094aa8ed7 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 May 2024 13:28:57 -0400 Subject: [PATCH 02/11] Minor: add a test for `current_time` (no args) (#10509) --- datafusion/sqllogictest/test_files/timestamps.slt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 13fb8fba0d31..5f75bca4f0fa 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2795,3 +2795,9 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New York'; # abbreviated timezone is 
not supported statement error SELECT '2023-03-12 02:00:00' AT TIME ZONE 'EDT'; + +# Test current_time without parentheses +query B +select current_time = current_time; +---- +true From cb88a79a1aeffbe543778da41c27d9c363be4dad Mon Sep 17 00:00:00 2001 From: shanretoo Date: Thu, 16 May 2024 01:42:01 +0800 Subject: [PATCH 03/11] fix: parsing timestamp with date format (#10476) * fix parsing timestamp with date format * add test in dates.slt --- datafusion/functions/src/datetime/common.rs | 4 +++- datafusion/functions/src/datetime/to_timestamp.rs | 4 ++++ datafusion/sqllogictest/test_files/dates.slt | 6 ++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index f0689ffd64e9..4f48ab188403 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -93,7 +93,9 @@ pub(crate) fn string_to_datetime_formatted( if let Err(e) = &dt { // no timezone or other failure, try without a timezone - let ndt = parsed.to_naive_datetime_with_offset(0); + let ndt = parsed + .to_naive_datetime_with_offset(0) + .or_else(|_| parsed.to_naive_date().map(|nd| nd.into())); if let Err(e) = &ndt { return Err(err(&e.to_string())); } diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index a7bcca62944c..af878b4505bc 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -670,6 +670,10 @@ mod tests { parse_timestamp_formatted("09-08-2020 13/42/29", "%m-%d-%Y %H/%M/%S") .unwrap() ); + assert_eq!( + 1642896000000000000, + parse_timestamp_formatted("2022-01-23", "%Y-%m-%d").unwrap() + ); } fn parse_timestamp_formatted(s: &str, format: &str) -> Result { diff --git a/datafusion/sqllogictest/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt index 32c0bd14e7cc..e21637bd8913 100644 --- a/datafusion/sqllogictest/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -224,5 +224,11 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', query error function unsupported data type at index 1: SELECT to_date(t.ts, make_array('%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+')) from ts_utf8_data as t +# verify to_date with format +query D +select to_date('2022-01-23', '%Y-%m-%d'); +---- +2022-01-23 + statement ok drop table ts_utf8_data From bed57df3e8dc04961755da593d345c61d0e1be39 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 15 May 2024 21:52:25 +0300 Subject: [PATCH 04/11] [MINOR]: Move pipeline checker rule to the end (#10502) * Move pipeline checker to last * Update slt --- datafusion/core/src/physical_optimizer/optimizer.rs | 10 +++++----- datafusion/sqllogictest/test_files/explain.slt | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 08cbf68fa617..416985983dfe 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -112,11 +112,6 @@ impl PhysicalOptimizer { // Remove the ancillary output requirement operator since we are done with the planning // phase. 
Arc::new(OutputRequirements::new_remove_mode()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), // The aggregation limiter will try to find situations where the accumulator count // is not tied to the cardinality, i.e. when the output of the aggregation is passed // into an `order by max(x) limit y`. In this case it will copy the limit value down @@ -129,6 +124,11 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // The PipelineChecker rule will reject non-runnable query plans that use + // pipeline-breaking operators on infinite input(s). The rule generates a + // diagnostic error message when this happens. It makes no changes to the + // given query plan; i.e. it only acts as a final gatekeeping rule. + Arc::new(PipelineChecker::new()), ]; Self::with_rules(rules) diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 3a4ac747ebd6..92c537f975ad 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -252,9 +252,9 @@ physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] @@ -311,9 +311,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, 
[(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -348,9 +348,9 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 From ea9c32540870615764f5e8ee1531b1c70dd27eed Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 May 2024 14:52:40 -0400 Subject: [PATCH 05/11] Minor: Extract parent/child limit calculation into a function, improve docs (#10501) * Minor: Extract parent/child limit calculation into a function, improve docs * Update datafusion/optimizer/src/push_down_limit.rs Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- datafusion/optimizer/src/push_down_limit.rs | 116 +++++++++++++------- 1 file changed, 77 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1af246fc556d..9190881335af 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,6 +17,7 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan +use std::cmp::min; use std::sync::Arc; use crate::optimizer::ApplyOrder; @@ -56,47 +57,12 @@ impl OptimizerRule for PushDownLimit { if let LogicalPlan::Limit(child) = &*limit.input { // Merge the Parent Limit and the Child Limit. - - // Case 0: Parent and Child are disjoint. (child_fetch <= skip) - // Before merging: - // |........skip........|---fetch-->| Parent Limit - // |...child_skip...|---child_fetch-->| Child Limit - // After merging: - // |.........(child_skip + skip).........| - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 1: Parent is beyond the range of Child. (skip < child_fetch <= skip + fetch) - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 2: Parent is in the range of Child. 
(skip + fetch < child_fetch) - // Before merging: - // |...skip...|---fetch-->| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---fetch-->| - let parent_skip = limit.skip; - let new_fetch = match (limit.fetch, child.fetch) { - (Some(fetch), Some(child_fetch)) => { - Some(min(fetch, child_fetch.saturating_sub(parent_skip))) - } - (Some(fetch), None) => Some(fetch), - (None, Some(child_fetch)) => { - Some(child_fetch.saturating_sub(parent_skip)) - } - (None, None) => None, - }; + let (skip, fetch) = + combine_limit(limit.skip, limit.fetch, child.skip, child.fetch); let plan = LogicalPlan::Limit(Limit { - skip: child.skip + parent_skip, - fetch: new_fetch, + skip, + fetch, input: Arc::new((*child.input).clone()), }); return self @@ -217,6 +183,78 @@ impl OptimizerRule for PushDownLimit { } } +/// Combines two limits into a single +/// +/// Returns the combined limit `(skip, fetch)` +/// +/// # Case 0: Parent and Child are disjoint. (`child_fetch <= skip`) +/// +/// ```text +/// Before merging: +/// |........skip........|---fetch-->| Parent Limit +/// |...child_skip...|---child_fetch-->| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |.........(child_skip + skip).........| +/// ``` +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 1: Parent is beyond the range of Child. (`skip < child_fetch <= skip + fetch`) +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent Limit +/// |...child_skip...|-------------child_fetch------------>| Child Limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 2: Parent is in the range of Child. 
(`skip + fetch < child_fetch`)
+/// Before merging:
+/// ```text
+/// |...skip...|---fetch-->| Parent Limit
+/// |...child_skip...|-------------child_fetch------------>| Child Limit
+/// ```
+///
+/// After merging:
+/// ```text
+/// |....(child_skip + skip)....|---fetch-->|
+/// ```
+fn combine_limit(
+    parent_skip: usize,
+    parent_fetch: Option<usize>,
+    child_skip: usize,
+    child_fetch: Option<usize>,
+) -> (usize, Option<usize>) {
+    let combined_skip = child_skip.saturating_add(parent_skip);
+
+    let combined_fetch = match (parent_fetch, child_fetch) {
+        (Some(parent_fetch), Some(child_fetch)) => {
+            Some(min(parent_fetch, child_fetch.saturating_sub(parent_skip)))
+        }
+        (Some(parent_fetch), None) => Some(parent_fetch),
+        (None, Some(child_fetch)) => Some(child_fetch.saturating_sub(parent_skip)),
+        (None, None) => None,
+    };
+
+    (combined_skip, combined_fetch)
+}
+
 fn push_down_join(join: &Join, limit: usize) -> Option<LogicalPlan> {
     use JoinType::*;
 
From 8199e9e6601d91320e395b43ba3a005ae7ba4816 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?=
Date: Thu, 16 May 2024 02:53:36 +0800
Subject: [PATCH 06/11] Fix window expr deserialization (#10506)

* Fix window expr deserialization

* Improve naming and doc

* Update window test
---
 .../core/tests/fuzz_cases/window_fuzz.rs      | 34 ++-----------------
 datafusion/physical-plan/src/windows/mod.rs   | 26 ++++++++++++++
 .../proto/src/physical_plan/from_proto.rs     | 12 ++++---
 .../tests/cases/roundtrip_physical_plan.rs    |  3 +-
 4 files changed, 38 insertions(+), 37 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 2514324a9541..fe0c408dc114 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use arrow_schema::{Field, Schema};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::windows::{
-    create_window_expr, BoundedWindowAggExec, WindowAggExec,
+    create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec,
 };
 use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted};
 use datafusion::physical_plan::{collect, InputOrderMode};
@@ -40,7 +39,6 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::expressions::{cast, col, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
-use itertools::Itertools;
 use test_utils::add_empty_batches;
 
 use hashbrown::HashMap;
@@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
             };
 
             let extended_schema =
-                schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
+                schema_add_window_field(&args, &schema, &window_fn, fn_name)?;
 
             let window_expr = create_window_expr(
                 &window_fn,
@@ -683,7 +681,7 @@ async fn run_window_test(
         exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
     }
 
-    let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
+    let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?;
 
     let usual_window_exec = Arc::new(WindowAggExec::try_new(
         vec![create_window_expr(
@@ -754,32 +752,6 @@ async fn run_window_test(
     Ok(())
 }
 
-// The planner has fully updated schema before calling the `create_window_expr`
-// Replicate the same for this test
-fn schema_add_window_fields(
-    args: &[Arc<dyn PhysicalExpr>],
-    schema: &Arc<Schema>,
-    window_fn: &WindowFunctionDefinition,
-    fn_name: &str,
-) -> Result<Arc<Schema>> {
-    let data_types = args
-        .iter()
-        .map(|e| e.clone().as_ref().data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-    let window_expr_return_type = window_fn.return_type(&data_types)?;
-    let mut window_fields = schema
-        .fields()
-        .iter()
-        .map(|f| f.as_ref().clone())
-        .collect_vec();
-    window_fields.extend_from_slice(&[Field::new(
-        fn_name,
-        window_expr_return_type,
-        true,
-    )]);
-    Ok(Arc::new(Schema::new(window_fields)))
-}
-
 /// Return randomly sized record batches with:
 /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
 /// one random int32 column x
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index d1223f78808c..42c630741cc9 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -42,6 +42,7 @@ use datafusion_physical_expr::{
     window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr},
     AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement,
 };
+use itertools::Itertools;
 
 mod bounded_window_agg_exec;
 mod window_agg_exec;
@@ -52,6 +53,31 @@ pub use datafusion_physical_expr::window::{
 };
 pub use window_agg_exec::WindowAggExec;
 
+/// Build field from window function and add it into schema
+pub fn schema_add_window_field(
+    args: &[Arc<dyn PhysicalExpr>],
+    schema: &Schema,
+    window_fn: &WindowFunctionDefinition,
+    fn_name: &str,
+) -> Result<Arc<Schema>> {
+    let data_types = args
+        .iter()
+        .map(|e| e.clone().as_ref().data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+    let window_expr_return_type = window_fn.return_type(&data_types)?;
+    let mut window_fields = schema
+        .fields()
+        .iter()
+        .map(|f| f.as_ref().clone())
+        .collect_vec();
+    window_fields.extend_from_slice(&[Field::new(
+        fn_name,
+        window_expr_return_type,
+        false,
+    )]);
+    Ok(Arc::new(Schema::new(window_fields)))
+}
+
 /// Create a physical expression for window function
 #[allow(clippy::too_many_arguments)]
 pub fn create_window_expr(
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index c907e991fb86..a290f30586ce 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -40,7 +40,7 @@ use datafusion::physical_plan::expressions::{
     in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr,
     Literal, NegativeExpr, NotExpr, TryCastExpr,
 };
-use datafusion::physical_plan::windows::create_window_expr;
+use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field};
 use datafusion::physical_plan::{
     ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr,
 };
@@ -155,14 +155,18 @@ pub fn parse_physical_window_expr(
             )
         })?;
 
+    let fun: WindowFunctionDefinition = convert_required!(proto.window_function)?;
+    let name = proto.name.clone();
+    let extended_schema =
+        schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
     create_window_expr(
-        &convert_required!(proto.window_function)?,
-        proto.name.clone(),
+        &fun,
+        name,
         &window_node_expr,
         &partition_by,
         &order_by,
         Arc::new(window_frame),
-        input_schema,
+        &extended_schema,
         false,
     )
 }
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 30a28081edff..dd8e450d3165 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ 
-253,8 +253,7 @@ fn roundtrip_nested_loop_join() -> Result<()> { fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); - let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); let window_frame = WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, From a331b36a245c8c31f28b7b08af55cfd01c5d537a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 15 May 2024 15:24:30 -0400 Subject: [PATCH 07/11] Update substrait requirement from 0.32.0 to 0.33.3 (#10516) Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.32.0...v0.33.3) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/substrait/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index dce8ce10b587..e4be6e68ff16 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,7 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.32.0" +substrait = "0.33.3" [dev-dependencies] tokio = { workspace = true } From c312ffe7d954563888a303beb8796848d20ff7c6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 15 May 2024 15:58:15 -0400 Subject: [PATCH 08/11] Stop copying LogicalPlan and Exprs in `TypeCoercion` (10% faster planning) (#10356) * Add `LogicalPlan::recompute_schema` for handling rewrite passes * Stop copying LogicalPlan and Exprs in `TypeCoercion` * Apply suggestions from code review Co-authored-by: Oleks V --------- Co-authored-by: Oleks V --- .../optimizer/src/analyzer/type_coercion.rs | 125 ++++++++++++------ 1 file changed, 88 insertions(+), 37 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 60b81aff9aaa..0f1f3ba7e729 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, @@ -31,8 +31,8 @@ use datafusion_expr::expr::{ self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use 
datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::binary::{
     comparison_coercion, get_input_types, like_coercion,
@@ -52,6 +52,7 @@ use datafusion_expr::{
 };
 
 use crate::analyzer::AnalyzerRule;
+use crate::utils::NamePreserver;
 
 #[derive(Default)]
 pub struct TypeCoercion {}
@@ -68,26 +69,28 @@ impl AnalyzerRule for TypeCoercion {
     }
 
     fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
-        analyze_internal(&DFSchema::empty(), &plan)
+        let empty_schema = DFSchema::empty();
+
+        let transformed_plan = plan
+            .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))?
+            .data;
+
+        Ok(transformed_plan)
     }
 }
 
+/// use the external schema to handle the correlated subqueries case
+///
+/// Assumes that children have already been optimized
 fn analyze_internal(
-    // use the external schema to handle the correlated subqueries case
     external_schema: &DFSchema,
-    plan: &LogicalPlan,
-) -> Result<LogicalPlan> {
-    // optimize child plans first
-    let new_inputs = plan
-        .inputs()
-        .iter()
-        .map(|p| analyze_internal(external_schema, p))
-        .collect::<Result<Vec<_>>>()?;
+    plan: LogicalPlan,
+) -> Result<Transformed<LogicalPlan>> {
     // get schema representing all available input fields. This is used for data type
     // resolution only, so order does not matter here
-    let mut schema = merge_schema(new_inputs.iter().collect());
+    let mut schema = merge_schema(plan.inputs());
 
-    if let LogicalPlan::TableScan(ts) = plan {
+    if let LogicalPlan::TableScan(ts) = &plan {
         let source_schema = DFSchema::try_from_qualified_schema(
             ts.table_name.clone(),
             &ts.source.schema(),
@@ -100,25 +103,75 @@ fn analyze_internal(
     // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
     schema.merge(external_schema);
 
-    let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };
-
-    let new_expr = plan
-        .expressions()
-        .into_iter()
-        .map(|expr| {
-            // ensure aggregate names don't change:
-            // https://github.com/apache/datafusion/issues/3555
-            rewrite_preserving_name(expr, &mut expr_rewrite)
-        })
-        .collect::<Result<Vec<_>>>()?;
-
-    plan.with_new_exprs(new_expr, new_inputs)
+    let mut expr_rewrite = TypeCoercionRewriter::new(&schema);
+
+    let name_preserver = NamePreserver::new(&plan);
+    // apply coercion rewrite all expressions in the plan individually
+    plan.map_expressions(|expr| {
+        let original_name = name_preserver.save(&expr)?;
+        expr.rewrite(&mut expr_rewrite)?
+            .map_data(|expr| original_name.restore(expr))
+    })?
+    // coerce join expressions specially
+    .map_data(|plan| expr_rewrite.coerce_joins(plan))?
+    // recompute the schema after the expressions have been rewritten as the types may have changed
+    .map_data(|plan| plan.recompute_schema())
 }
 
 pub(crate) struct TypeCoercionRewriter<'a> {
     pub(crate) schema: &'a DFSchema,
 }
 
+impl<'a> TypeCoercionRewriter<'a> {
+    fn new(schema: &'a DFSchema) -> Self {
+        Self { schema }
+    }
+
+    /// Coerce join equality expressions
+    ///
+    /// Joins must be treated specially as their equality expressions are stored
+    /// as a parallel list of left and right expressions, rather than a single
+    /// equality expression
+    ///
+    /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored
+    /// as a list of `(t1.a, t2.b), (t1.x, t2.y)`
+    fn coerce_joins(&mut self, plan: LogicalPlan) -> Result<LogicalPlan> {
+        let LogicalPlan::Join(mut join) = plan else {
+            return Ok(plan);
+        };
+
+        join.on = join
+            .on
+            .into_iter()
+            .map(|(lhs, rhs)| {
+                // coerce the arguments as though they were a single binary equality
+                // expression
+                let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?;
+                Ok((lhs, rhs))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(LogicalPlan::Join(join))
+    }
+
+    fn coerce_binary_op(
+        &self,
+        left: Expr,
+        op: Operator,
+        right: Expr,
+    ) -> Result<(Expr, Expr)> {
+        let (left_type, right_type) = get_input_types(
+            &left.get_type(self.schema)?,
+            &op,
+            &right.get_type(self.schema)?,
+        )?;
+        Ok((
+            left.cast_to(&left_type, self.schema)?,
+            right.cast_to(&right_type, self.schema)?,
+        ))
+    }
+}
+
 impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
     type Node = Expr;
 
@@ -131,14 +184,15 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 subquery,
                 outer_ref_columns,
             }) => {
-                let new_plan = analyze_internal(self.schema, &subquery)?;
+                let new_plan = analyze_internal(self.schema, unwrap_arc(subquery))?.data;
                 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
                     subquery: Arc::new(new_plan),
                     outer_ref_columns,
                 })))
            }
            Expr::Exists(Exists { subquery, negated }) => {
-                let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
+                let new_plan =
+                    analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data;
                 Ok(Transformed::yes(Expr::Exists(Exists {
                     subquery: Subquery {
                         subquery: Arc::new(new_plan),
@@ -152,7 +206,8 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 subquery,
                 negated,
             }) => {
-                let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
+                let new_plan =
+                    analyze_internal(self.schema, unwrap_arc(subquery.subquery))?.data;
                 let expr_type = expr.get_type(self.schema)?;
                 let subquery_type = new_plan.schema().field(0).data_type();
                 let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
@@ -221,15 +276,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 ))))
             }
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                let (left_type, right_type) = get_input_types(
-                    &left.get_type(self.schema)?,
-                    &op,
-                    &right.get_type(self.schema)?,
-                )?;
+                let (left, right) = self.coerce_binary_op(*left, op, *right)?;
                 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
-                    Box::new(left.cast_to(&left_type, self.schema)?),
+                    Box::new(left),
                     op,
-                    Box::new(right.cast_to(&right_type, self.schema)?),
+                    Box::new(right),
                 ))))
             }
             Expr::Between(Between {

From eddec8e78865c0f17bd089af641492b1d8e8a411 Mon Sep 17 00:00:00 2001
From: Jax Liu
Date: Thu, 16 May 2024 07:43:17 +0800
Subject: [PATCH 09/11] Implement unparse `IS_NULL` to String and enhance the tests (#10529)

* implement unparse is_null and add test

* format the code
---
 datafusion/sql/src/unparser/expr.rs | 27 +++++++++++++++++++++++++--
1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 804fa6d306b4..23e3d9ab3594 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -391,7 +391,9 @@ impl Unparser<'_> { Expr::ScalarVariable(_, _) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") } - Expr::IsNull(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + Expr::IsNull(expr) => { + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) + } Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), Expr::GetIndexedField(_) => { not_impl_err!("Unsupported Expr conversion: {expr:?}") @@ -863,7 +865,7 @@ mod tests { use datafusion_expr::{ case, col, exists, expr::{AggregateFunction, AggregateFunctionDefinition}, - lit, not, not_exists, table_scan, wildcard, ColumnarValue, ScalarUDF, + lit, not, not_exists, table_scan, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; @@ -933,6 +935,14 @@ mod tests { .otherwise(lit(ScalarValue::Null))?, r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, ), + ( + when(col("a").is_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NULL THEN true ELSE false END"#, + ), + ( + when(col("a").is_not_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN "a" IS NOT NULL THEN true ELSE false END"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), @@ -959,6 +969,18 @@ mod tests { ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), r#"dummy_udf("a", "b")"#, ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_null(), + r#"dummy_udf("a", "b") IS NULL"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_not_null(), + r#"dummy_udf("a", "b") IS NOT NULL"#, + ), ( Expr::Like(Like { negated: true, @@ -1081,6 +1103,7 @@ mod tests { r#"COUNT(*) OVER (ORDER BY "a" DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), (col("a").is_not_null(), r#""a" IS NOT NULL"#), + (col("a").is_null(), r#""a" IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), r#"(("a" + "b") > 4) IS TRUE"#, From 626c6bc8bf9b10aaf416b7494ae2c31c14cec5ce Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 16 May 2024 07:56:31 +0800 Subject: [PATCH 10/11] support merge batch for distinct array aggregate (#10526) Signed-off-by: jayzhan211 --- .../src/aggregate/array_agg_distinct.rs | 11 ++- .../sqllogictest/test_files/aggregate.slt | 67 +++++++++++++++++++ 2 files changed, 72 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index b8671c39a943..244a44acdcb5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -153,12 +153,11 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt index 40d66f9b52ce..78421d0b6431 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -198,6 +198,73 @@ statement error This feature is not implemented: LIMIT not supported in ARRAY_AG SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 +# Test distinct aggregate function with merge batch +query II +with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 + ---- The order is non-deterministic, verify with length +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +3 1 + +# It has only AggregateExec with FinalPartitioned mode, so `merge_batch` is used +# If the plan is changed, whether the `merge_batch` is used should be verified to ensure the test coverage +query TT +explain with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +logical_plan +01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]] +03)----SubqueryAlias: a +04)------SubqueryAlias: a +05)--------Union +06)----------Projection: Int64(1) AS id, Int64(2) AS foo +07)------------EmptyRelation +08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +09)------------EmptyRelation +10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +11)------------EmptyRelation +12)----------Projection: Int64(1) AS id, Int64(3) AS foo +13)------------EmptyRelation +14)----------Projection: Int64(1) AS id, Int64(2) AS foo +15)------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))@2 as SUM(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +06)----------UnionExec +07)------------ProjectionExec: expr=[1 as id, 2 as foo] +08)--------------PlaceholderRowExec +09)------------ProjectionExec: expr=[1 as id, NULL as foo] +10)--------------PlaceholderRowExec +11)------------ProjectionExec: expr=[1 as id, NULL as foo] +12)--------------PlaceholderRowExec +13)------------ProjectionExec: expr=[1 as id, 3 as foo] +14)--------------PlaceholderRowExec +15)------------ProjectionExec: expr=[1 as id, 2 as foo] +16)--------------PlaceholderRowExec + + # FIX: custom absolute values # csv_query_avg_multi_batch From 5a8348f7111b2b0d39f2bd3fd1b1534338113b9f Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 16 May 2024 08:30:58 +0800 Subject: [PATCH 11/11] UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525) * extends args Signed-off-by: jayzhan211 * reuse accumulator args Signed-off-by: jayzhan211 * fix example Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 15 ++--- .../examples/simplify_udaf_expression.rs | 11 +--- 
.../user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/expr_fn.rs | 8 +-- datafusion/expr/src/function.rs | 51 ++++++++++++----- datafusion/expr/src/udaf.rs | 57 +++++++++++-------- .../functions-aggregate/src/covariance.rs | 22 +++---- .../functions-aggregate/src/first_last.rs | 15 ++--- .../simplify_expressions/expr_simplifier.rs | 6 +- .../physical-expr-common/src/aggregate/mod.rs | 43 ++++++++++---- 10 files changed, 128 insertions(+), 102 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..cf284472212f 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,8 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, - GroupsAccumulator, Signature, + function::{AccumulatorArgs, StateFieldsArgs}, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields( - &self, - _name: &str, - value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", value_type, true), + Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` /// which is used for cases when there are grouping columns in the query - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs index 92deb20272e4..08b6bcab0190 100644 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -17,7 +17,7 @@ use arrow_schema::{Field, Schema}; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; use datafusion_expr::simplify::SimplifyInfo; use std::{any::Any, sync::Arc}; @@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf { unimplemented!("should not be invoked") } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { unimplemented!("should not be invoked") } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..d199f04ba781 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn groups_accumulator_supported(&self) -> bool { + fn 
groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d976a12cc4f..64763a973687 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -23,6 +23,7 @@ use crate::expr::{ }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, + StateFieldsArgs, }; use crate::{ aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, @@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 4e4d77924a9d..714cfa1af671 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -19,7 +19,7 @@ use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -41,11 +41,14 @@ pub type ReturnTypeFunction = /// [`AccumulatorArgs`] contains information about how an aggregate /// function was called, including the types of its arguments and any optional /// ordering expressions. +#[derive(Debug)] pub struct AccumulatorArgs<'a> { /// The return type of the aggregate function. pub data_type: &'a DataType, + /// The schema of the input arguments pub schema: &'a Schema, + /// Whether to ignore nulls. /// /// SQL allows the user to specify `IGNORE NULLS`, for example: @@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> { /// /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. pub sort_exprs: &'a [Expr], + + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The number of arguments the aggregate function takes. + pub args_num: usize, } -impl<'a> AccumulatorArgs<'a> { - pub fn new( - data_type: &'a DataType, - schema: &'a Schema, - ignore_nulls: bool, - sort_exprs: &'a [Expr], - ) -> Self { - Self { - data_type, - schema, - ignore_nulls, - sort_exprs, - } - } +/// [`StateFieldsArgs`] contains information about the fields that an +/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`]. +/// +/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields +pub struct StateFieldsArgs<'a> { + /// The name of the aggregate function. + pub name: &'a str, + + /// The input type of the aggregate function. + pub input_type: &'a DataType, + + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The ordering fields of the aggregate function. + pub ordering_fields: &'a [Field], + + /// Whether the aggregate function is distinct. + pub is_distinct: bool, } /// Factory that returns an accumulator for the given aggregate function. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 95121d78e7aa..4fd8d51679f0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,9 @@ //! 
[`AggregateUDF`]: User Defined Aggregate Functions
 
-use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
+use crate::function::{
+    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
+};
 use crate::groups_accumulator::GroupsAccumulator;
 use crate::utils::format_state_name;
 use crate::{Accumulator, Expr};
@@ -177,18 +179,13 @@ impl AggregateUDF {
     /// for more details.
     ///
     /// This is used to support multi-phase aggregations
-    pub fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
-        self.inner.state_fields(name, value_type, ordering_fields)
+    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        self.inner.state_fields(args)
     }
 
     /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
-    pub fn groups_accumulator_supported(&self) -> bool {
-        self.inner.groups_accumulator_supported()
+    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        self.inner.groups_accumulator_supported(args)
     }
 
     /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
@@ -232,7 +229,7 @@ where
 /// # use arrow::datatypes::DataType;
 /// # use datafusion_common::{DataFusionError, plan_err, Result};
 /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
-/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
 /// # use arrow::datatypes::Schema;
 /// # use arrow::datatypes::Field;
 /// #[derive(Debug, Clone)]
@@ -261,9 +258,9 @@ where
 /// }
 /// // This is the accumulator factory; DataFusion uses it to create new accumulators.
 /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
-/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
+/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
 ///     Ok(vec![
-///         Field::new("value", value_type, true),
+///         Field::new("value", args.return_type.clone(), true),
 ///         Field::new("ordering", DataType::UInt32, true)
 ///     ])
 /// }
@@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     /// The name of the fields must be unique within the query and thus should
     /// be derived from `name`. See [`format_state_name`] for a utility function
     /// to generate a unique name.
-    fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
-        let value_fields = vec![Field::new(
-            format_state_name(name, "value"),
-            value_type,
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let fields = vec![Field::new(
+            format_state_name(args.name, "value"),
+            args.return_type.clone(),
             true,
         )];
 
-        Ok(value_fields.into_iter().chain(ordering_fields).collect())
+        Ok(fields
+            .into_iter()
+            .chain(args.ordering_fields.to_vec())
+            .collect())
     }
 
     /// If the aggregate expression has a specialized
@@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     /// `Self::accumulator` for certain queries, such as when this aggregate is
     /// used as a window function or when there no GROUP BY columns in the
     /// query.
-    fn groups_accumulator_supported(&self) -> bool {
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
         false
     }
 
@@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn simplify(&self) -> Option<AggregateFunctionSimplification> {
         None
    }
+
+    /// Returns the reverse expression of the aggregate function.
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::NotSupported
+    }
+}
+
+pub enum ReversedUDAF {
+    /// The expression is the same as the original expression, like SUM, COUNT
+    Identical,
+    /// The expression does not support reverse calculation, like ArrayAgg
+    NotSupported,
+    /// The expression is different from the original expression
+    Reversed(Arc<AggregateUDF>),
+}
 
 /// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs
index 1210e1529dbb..6f03b256fd9f 100644
--- a/datafusion/functions-aggregate/src/covariance.rs
+++ b/datafusion/functions-aggregate/src/covariance.rs
@@ -30,8 +30,10 @@ use datafusion_common::{
     ScalarValue,
 };
 use datafusion_expr::{
-    function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
-    utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
+    function::{AccumulatorArgs, StateFieldsArgs},
+    type_coercion::aggregates::NUMERICS,
+    utils::format_state_name,
+    Accumulator, AggregateUDFImpl, Signature, Volatility,
 };
 use datafusion_physical_expr_common::aggregate::stats::StatsType;
 
@@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample {
         Ok(DataType::Float64)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let name = args.name;
         Ok(vec![
             Field::new(format_state_name(name, "count"), DataType::UInt64, true),
             Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
@@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation {
         Ok(DataType::Float64)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let name = args.name;
         Ok(vec![
             Field::new(format_state_name(name, "count"), DataType::UInt64, true),
             Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index e3b685e90376..5d3d48344014 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
 use datafusion_common::{
     arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
 };
-use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::type_coercion::aggregates::NUMERICS;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::{
@@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue {
             .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
         let mut fields = vec![Field::new(
-            format_state_name(name, "first_value"),
-            value_type,
+            format_state_name(args.name, "first_value"),
+            args.return_type.clone(),
             true,
         )];
-        fields.extend(ordering_fields);
+        fields.extend(args.ordering_fields.to_vec());
         fields.push(Field::new("is_set", DataType::Boolean, true));
         Ok(fields)
     }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 55052542a8bf..455d659fb25e 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
 mod tests {
     use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
     use datafusion_expr::{
-        function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
+        function::{AccumulatorArgs, AggregateFunctionSimplification},
+        interval_arithmetic::Interval,
+        *,
     };
     use std::{
         collections::HashMap,
@@ -3783,7 +3785,7 @@ mod tests {
             unimplemented!("not needed for tests")
         }
 
-        fn groups_accumulator_supported(&self) -> bool {
+        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
             unimplemented!("not needed for testing")
         }
 
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 05641b373b72..da24f335b2f8 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -20,6 +20,7 @@ pub mod utils;
 
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::type_coercion::aggregates::check_arg_count;
 use datafusion_expr::{
     function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
@@ -74,6 +75,7 @@ pub fn create_aggregate_expr(
         ignore_nulls,
         ordering_fields,
         is_distinct,
+        input_type: input_exprs_types[0].clone(),
     }))
 }
 
@@ -166,6 +168,7 @@ pub struct AggregateFunctionExpr {
     ignore_nulls: bool,
     ordering_fields: Vec<Field>,
     is_distinct: bool,
+    input_type: DataType,
 }
 
 impl AggregateFunctionExpr {
@@ -191,11 +194,15 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
-        self.fun.state_fields(
-            self.name(),
-            self.data_type.clone(),
-            self.ordering_fields.clone(),
-        )
+        let args = StateFieldsArgs {
+            name: &self.name,
+            input_type: &self.input_type,
+            return_type: &self.data_type,
+            ordering_fields: &self.ordering_fields,
+            is_distinct: self.is_distinct,
+        };
+
+        self.fun.state_fields(args)
     }
 
     fn field(&self) -> Result<Field> {
@@ -203,12 +210,15 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        let acc_args = AccumulatorArgs::new(
-            &self.data_type,
-            &self.schema,
-            self.ignore_nulls,
-            &self.sort_exprs,
-        );
+        let acc_args = AccumulatorArgs {
+            data_type: &self.data_type,
+            schema: &self.schema,
+            ignore_nulls: self.ignore_nulls,
+            sort_exprs: &self.sort_exprs,
+            is_distinct: self.is_distinct,
+            input_type: &self.input_type,
+            args_num: self.args.len(),
+        };
 
         self.fun.accumulator(acc_args)
     }
@@ -273,7 +283,16 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn groups_accumulator_supported(&self) -> bool {
-        self.fun.groups_accumulator_supported()
+        let args = AccumulatorArgs {
+            data_type: &self.data_type,
+            schema: &self.schema,
+            ignore_nulls: self.ignore_nulls,
+            sort_exprs: &self.sort_exprs,
+            is_distinct: self.is_distinct,
+            input_type: &self.input_type,
+            args_num: self.args.len(),
+        };
+        self.fun.groups_accumulator_supported(args)
    }
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {