diff --git a/.asf.yaml b/.asf.yaml index e2ad0198a303..c605a4692974 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -45,7 +45,9 @@ github: features: issues: true protected_branches: - main: { } + main: + required_pull_request_reviews: + required_approving_review_count: 1 # publishes the content of the `asf-site` branch to # https://arrow.apache.org/datafusion/ diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index ec66a1270cb8..442e6e4009f6 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,4 +1,4 @@ -# Which issue does this PR close? +## Which issue does this PR close? -# What changes are included in this PR? +## What changes are included in this PR? -# Are these changes tested? +## Are these changes tested? -# Are there any user-facing changes? +## Are there any user-facing changes? `col = !true` if op != Operator::Eq && op != Operator::NotEq { - return Err(DataFusionError::Plan( - "Not with operator other than Eq / NotEq is not supported".to_string(), - )); + return plan_err!("Not with operator other than Eq / NotEq is not supported"); } if not .arg() @@ -588,14 +584,10 @@ fn rewrite_expr_to_prunable( let right = Arc::new(phys_expr::NotExpr::new(scalar_expr.clone())); Ok((left, reverse_operator(op)?, right)) } else { - Err(DataFusionError::Plan(format!( - "Not with complex expression {column_expr:?} is not supported" - ))) + plan_err!("Not with complex expression {column_expr:?} is not supported") } } else { - Err(DataFusionError::Plan(format!( - "column expression {column_expr:?} is not supported" - ))) + plan_err!("column expression {column_expr:?} is not supported") } } @@ -630,9 +622,9 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re ) { Ok(()) } else { - Err(DataFusionError::Plan(format!( + plan_err!( "Try Cast/Cast with from type {from_type} to type {to_type} is not supported" - ))) + ) } } @@ -841,85 +833,85 @@ fn build_predicate_expression( fn build_statistics_expr( expr_builder: &mut PruningExpressionBuilder, ) -> Result> { - let statistics_expr: Arc = - match expr_builder.op() { - Operator::NotEq => { - // column != literal => (min, max) = literal => - // !(min != literal && max != literal) ==> - // min != literal || literal != max - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; - Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::NotEq, - expr_builder.scalar_expr().clone(), - )), - Operator::Or, - Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), - Operator::NotEq, - max_column_expr, - )), - )) - } - Operator::Eq => { - // column = literal => (min, max) = literal => min <= literal && literal <= max - // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) - let min_column_expr = expr_builder.min_column_expr()?; - let max_column_expr = expr_builder.max_column_expr()?; + let statistics_expr: Arc = match expr_builder.op() { + Operator::NotEq => { + // column != literal => (min, max) = literal => + // !(min != literal && max != literal) ==> + // min != literal || literal != max + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Arc::new(phys_expr::BinaryExpr::new( Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::BinaryExpr::new( - min_column_expr, - Operator::LtEq, - expr_builder.scalar_expr().clone(), - )), - Operator::And, - 
Arc::new(phys_expr::BinaryExpr::new( - expr_builder.scalar_expr().clone(), - Operator::LtEq, - max_column_expr, - )), - )) - } - Operator::Gt => { - // column > literal => (min, max) > literal => max > literal - Arc::new(phys_expr::BinaryExpr::new( - expr_builder.max_column_expr()?, - Operator::Gt, + min_column_expr, + Operator::NotEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::GtEq => { - // column >= literal => (min, max) >= literal => max >= literal + )), + Operator::Or, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.max_column_expr()?, - Operator::GtEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::Lt => { - // column < literal => (min, max) < literal => min < literal + Operator::NotEq, + max_column_expr, + )), + )) + } + Operator::Eq => { + // column = literal => (min, max) = literal => min <= literal && literal <= max + // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) + let min_column_expr = expr_builder.min_column_expr()?; + let max_column_expr = expr_builder.max_column_expr()?; + Arc::new(phys_expr::BinaryExpr::new( Arc::new(phys_expr::BinaryExpr::new( - expr_builder.min_column_expr()?, - Operator::Lt, + min_column_expr, + Operator::LtEq, expr_builder.scalar_expr().clone(), - )) - } - Operator::LtEq => { - // column <= literal => (min, max) <= literal => min <= literal + )), + Operator::And, Arc::new(phys_expr::BinaryExpr::new( - expr_builder.min_column_expr()?, - Operator::LtEq, expr_builder.scalar_expr().clone(), - )) - } - // other expressions are not supported - _ => return Err(DataFusionError::Plan( + Operator::LtEq, + max_column_expr, + )), + )) + } + Operator::Gt => { + // column > literal => (min, max) > literal => max > literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::Gt, + expr_builder.scalar_expr().clone(), + )) + } + Operator::GtEq => { + // column >= literal => (min, max) >= literal => max >= literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::GtEq, + expr_builder.scalar_expr().clone(), + )) + } + Operator::Lt => { + // column < literal => (min, max) < literal => min < literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::Lt, + expr_builder.scalar_expr().clone(), + )) + } + Operator::LtEq => { + // column <= literal => (min, max) <= literal => min <= literal + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::LtEq, + expr_builder.scalar_expr().clone(), + )) + } + // other expressions are not supported + _ => { + return plan_err!( "expressions other than (neq, eq, gt, gteq, lt, lteq) are not supported" - .to_string(), - )), - }; + ) + } + }; Ok(statistics_expr) } diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 97fc79720bc2..644aea1f3812 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -56,7 +56,7 @@ use crate::physical_plan::{with_new_children_if_necessary, Distribution, Executi use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; use datafusion_common::utils::{get_at_indices, longest_consecutive_prefix}; -use datafusion_common::DataFusionError; +use datafusion_common::{plan_err, DataFusionError}; use datafusion_physical_expr::utils::{ convert_to_expr, get_indices_of_matching_exprs, ordering_satisfy, 
ordering_satisfy_requirement_concrete, @@ -597,17 +597,16 @@ fn analyze_window_sort_removal( sort_tree: &mut ExecTree, window_exec: &Arc, ) -> Result> { - let (window_expr, partition_keys) = if let Some(exec) = - window_exec.as_any().downcast_ref::() - { - (exec.window_expr(), &exec.partition_keys) - } else if let Some(exec) = window_exec.as_any().downcast_ref::() { - (exec.window_expr(), &exec.partition_keys) - } else { - return Err(DataFusionError::Plan( - "Expects to receive either WindowAggExec of BoundedWindowAggExec".to_string(), - )); - }; + let (window_expr, partition_keys) = + if let Some(exec) = window_exec.as_any().downcast_ref::() { + (exec.window_expr(), &exec.partition_keys) + } else if let Some(exec) = window_exec.as_any().downcast_ref::() { + (exec.window_expr(), &exec.partition_keys) + } else { + return plan_err!( + "Expects to receive either WindowAggExec of BoundedWindowAggExec" + ); + }; let partitionby_exprs = window_expr[0].partition_by(); let orderby_sort_keys = window_expr[0].order_by(); @@ -814,10 +813,7 @@ fn get_sort_exprs( sort_preserving_merge_exec.fetch(), )) } else { - Err(DataFusionError::Plan( - "Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec" - .to_string(), - )) + plan_err!("Given ExecutionPlan is not a SortExec or a SortPreservingMergeExec") } } diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 108f753800f1..aaa2b6de2b6f 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -14,25 +14,28 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +use std::sync::Arc; + use crate::physical_optimizer::utils::{add_sort_above, is_limit, is_union, is_window}; use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::joins::utils::JoinSide; +use crate::physical_plan::joins::utils::{calculate_join_output_ordering, JoinSide}; use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; + use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::{ ordering_satisfy_requirement, requirements_compatible, }; use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; + use itertools::izip; -use std::ops::Deref; -use std::sync::Arc; /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. 
In some cases, we can reduce the total @@ -223,31 +226,24 @@ fn pushdown_requirement_to_children( let expr_source_side = expr_source_sides(&parent_required_expr, smj.join_type, left_columns_len); match expr_source_side { - Some(JoinSide::Left) if maintains_input_order[0] => { + Some(JoinSide::Left) => try_pushdown_requirements_to_join( + smj, + parent_required, + parent_required_expr, + JoinSide::Left, + ), + Some(JoinSide::Right) => { + let right_offset = + smj.schema().fields.len() - smj.right.schema().fields.len(); + let new_right_required = + shift_right_required(parent_required.ok_or_else(err)?, right_offset)?; + let new_right_required_expr = PhysicalSortRequirement::to_sort_exprs( + new_right_required.iter().cloned(), + ); try_pushdown_requirements_to_join( - plan, + smj, parent_required, - parent_required_expr, - JoinSide::Left, - ) - } - Some(JoinSide::Right) if maintains_input_order[1] => { - let new_right_required = match smj.join_type { - JoinType::Inner | JoinType::Right => shift_right_required( - parent_required.ok_or_else(err)?, - left_columns_len, - )?, - JoinType::RightSemi | JoinType::RightAnti => { - parent_required.ok_or_else(err)?.to_vec() - } - _ => Err(DataFusionError::Plan( - "Unexpected SortMergeJoin type here".to_string(), - ))?, - }; - try_pushdown_requirements_to_join( - plan, - Some(new_right_required.deref()), - parent_required_expr, + new_right_required_expr, JoinSide::Right, ) } @@ -316,39 +312,45 @@ fn determine_children_requirement( RequirementsCompatibility::NonCompatible } } - fn try_pushdown_requirements_to_join( - plan: &Arc, + smj: &SortMergeJoinExec, parent_required: Option<&[PhysicalSortRequirement]>, sort_expr: Vec, push_side: JoinSide, ) -> Result>>>> { - let child_idx = match push_side { - JoinSide::Left => 0, - JoinSide::Right => 1, + let left_ordering = smj.left.output_ordering().unwrap_or(&[]); + let right_ordering = smj.right.output_ordering().unwrap_or(&[]); + let (new_left_ordering, new_right_ordering) = match push_side { + JoinSide::Left => (sort_expr.as_slice(), right_ordering), + JoinSide::Right => (left_ordering, sort_expr.as_slice()), }; - let required_input_ordering = plan.required_input_ordering(); - let request_child = required_input_ordering[child_idx].as_deref(); - let child_plan = plan.children()[child_idx].clone(); - match determine_children_requirement(parent_required, request_child, child_plan) { - RequirementsCompatibility::Satisfy => Ok(None), - RequirementsCompatibility::Compatible(adjusted) => { - let new_adjusted = match push_side { - JoinSide::Left => { - vec![adjusted, required_input_ordering[1].clone()] - } - JoinSide::Right => { - vec![required_input_ordering[0].clone(), adjusted] - } - }; - Ok(Some(new_adjusted)) - } - RequirementsCompatibility::NonCompatible => { - // Can not push down, add new SortExec - add_sort_above(&mut plan.clone(), sort_expr, None)?; - Ok(None) + let new_output_ordering = calculate_join_output_ordering( + new_left_ordering, + new_right_ordering, + smj.join_type, + &smj.on, + smj.left.schema().fields.len(), + &smj.maintains_input_order(), + Some(SortMergeJoinExec::probe_side(&smj.join_type)), + )?; + Ok(ordering_satisfy_requirement( + new_output_ordering.as_deref(), + parent_required, + || smj.equivalence_properties(), + || smj.ordering_equivalence_properties(), + ) + .then(|| { + let required_input_ordering = smj.required_input_ordering(); + let new_req = Some(PhysicalSortRequirement::from_sort_exprs(&sort_expr)); + match push_side { + JoinSide::Left => { + vec![new_req, 
required_input_ordering[1].clone()] + } + JoinSide::Right => { + vec![required_input_ordering[0].clone(), new_req] + } } - } + })) } fn expr_source_sides( @@ -422,10 +424,9 @@ fn shift_right_required( if new_right_required.len() == parent_required.len() { Ok(new_right_required) } else { - Err(DataFusionError::Plan( + plan_err!( "Expect to shift all the parent required column indexes for SortMergeJoin" - .to_string(), - )) + ) } } diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index f35b186f0815..473d4eb131f1 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -25,11 +25,12 @@ use crate::physical_plan::{ DisplayFormatType, Distribution, EquivalenceProperties, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; + use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::utils::longest_consecutive_prefix; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Accumulator; use datafusion_physical_expr::{ @@ -40,6 +41,7 @@ use datafusion_physical_expr::{ AggregateExpr, LexOrdering, LexOrderingReq, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, }; + use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -84,6 +86,20 @@ pub enum AggregateMode { SinglePartitioned, } +impl AggregateMode { + /// Checks whether this aggregation step describes a "first stage" calculation. + /// In other words, its input is not another aggregation result and the + /// `merge_batch` method will not be called for these modes. + fn is_first_stage(&self) -> bool { + match self { + AggregateMode::Partial + | AggregateMode::Single + | AggregateMode::SinglePartitioned => true, + AggregateMode::Final | AggregateMode::FinalPartitioned => false, + } + } +} + /// Group By expression modes /// /// `PartiallyOrdered` and `FullyOrdered` are used to reason about @@ -95,9 +111,6 @@ pub enum AggregateMode { /// previous combinations are guaranteed never to appear again #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum GroupByOrderMode { - /// The input is not (known to be) ordered by any of the - /// expressions in the GROUP BY clause. - None, /// The input is known to be ordered by a preset (prefix but /// possibly reordered) of the expressions in the `GROUP BY` clause. /// @@ -475,13 +488,13 @@ fn calc_required_input_ordering( }; for (is_reverse, aggregator_requirement) in aggregator_requirements.into_iter() { if let Some(AggregationOrdering { - ordering, // If the mode is FullyOrdered or PartiallyOrdered (i.e. we are // running with bounded memory, without breaking the pipeline), // then we append the aggregator ordering requirement to the existing // ordering. This way, we can still run with bounded memory. mode: GroupByOrderMode::FullyOrdered | GroupByOrderMode::PartiallyOrdered, order_indices, + .. 
}) = aggregation_ordering { // Get the section of the input ordering that enables us to run in @@ -495,32 +508,17 @@ fn calc_required_input_ordering( let mut requirement = PhysicalSortRequirement::from_sort_exprs(requirement_prefix.iter()); for req in aggregator_requirement { - if requirement.iter().all(|item| req.expr.ne(&item.expr)) { - requirement.push(req.clone()); - } - // In partial mode, append required ordering of the aggregator to the output ordering. - // In case of multiple partitions, this enables us to reduce partitions correctly. - if matches!(mode, AggregateMode::Partial) - && ordering.iter().all(|item| req.expr.ne(&item.expr)) + // Final and FinalPartitioned modes don't enforce ordering + // requirements since order-sensitive aggregators handle such + // requirements during merging. + if mode.is_first_stage() + && requirement.iter().all(|item| req.expr.ne(&item.expr)) { - ordering.push(req.into()); + requirement.push(req); } } required_input_ordering = requirement; - } else { - // If there was no pre-existing output ordering, the output ordering is simply the required - // ordering of the aggregator in partial mode. - if matches!(mode, AggregateMode::Partial) - && !aggregator_requirement.is_empty() - { - *aggregation_ordering = Some(AggregationOrdering { - mode: GroupByOrderMode::None, - order_indices: vec![], - ordering: PhysicalSortRequirement::to_sort_exprs( - aggregator_requirement.clone(), - ), - }); - } + } else if mode.is_first_stage() { required_input_ordering = aggregator_requirement; } // Keep track of the direction from which required_input_ordering is constructed: @@ -557,10 +555,9 @@ fn calc_required_input_ordering( *aggr_expr = reverse; *ob_expr = ob_expr.as_ref().map(|obs| reverse_order_bys(obs)); } else { - return Err(DataFusionError::Plan( + return plan_err!( "Aggregate expression should have a reverse expression" - .to_string(), - )); + ); } } Ok(()) @@ -596,12 +593,16 @@ impl AggregateExec { .iter() .zip(order_by_expr.into_iter()) .map(|(aggr_expr, fn_reqs)| { - // If aggregation function is ordering sensitive, keep ordering requirement as is; otherwise ignore requirement - if is_order_sensitive(aggr_expr) { - fn_reqs - } else { - None - } + // If the aggregation function is order-sensitive and we are + // performing a "first stage" calculation, keep the ordering + // requirement as is; otherwise ignore the ordering requirement. + // In non-first stage modes, we accumulate data (using `merge_batch`) + // from different partitions (i.e. merge partial results). During + // this merge, we consider the ordering of each partial result. + // Hence, we do not need to use the ordering requirement in such + // modes as long as partial results are generated with the + // correct ordering. + fn_reqs.filter(|_| is_order_sensitive(aggr_expr) && mode.is_first_stage()) }) .collect::>(); let mut aggregator_reverse_reqs = None; @@ -645,7 +646,6 @@ impl AggregateExec { } let mut aggregation_ordering = calc_aggregation_ordering(&input, &group_by); - let required_input_ordering = calc_required_input_ordering( &input, &mut aggr_expr, @@ -847,9 +847,9 @@ impl ExecutionPlan for AggregateExec { if children[0] { if self.aggregation_ordering.is_none() { // Cannot run without breaking pipeline. 
- Err(DataFusionError::Plan( - "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs.".to_string(), - )) + plan_err!( + "Aggregate Error: `GROUP BY` clauses with columns without ordering and GROUPING SETS are not supported for unbounded inputs." + ) } else { Ok(true) } @@ -1216,14 +1216,26 @@ fn evaluate_group_by( mod tests { use super::*; use crate::execution::context::SessionConfig; + use crate::physical_plan::aggregates::GroupByOrderMode::{ + FullyOrdered, PartiallyOrdered, + }; use crate::physical_plan::aggregates::{ get_finest_requirement, get_working_mode, AggregateExec, AggregateMode, PhysicalGroupBy, }; + use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; + use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::{col, Avg}; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::{ + DisplayAs, ExecutionPlan, Partitioning, RecordBatchStream, + SendableRecordBatchStream, Statistics, + }; + use crate::prelude::SessionContext; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{assert_is_pending, csv_exec_sorted}; - use crate::{assert_batches_sorted_eq, physical_plan::common}; + use crate::{assert_batches_eq, assert_batches_sorted_eq, physical_plan::common}; + use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -1231,27 +1243,18 @@ mod tests { use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Column, Count, FirstValue, Median, + lit, ApproxDistinct, Column, Count, FirstValue, LastValue, Median, }; use datafusion_physical_expr::{ AggregateExpr, EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, }; - use futures::{FutureExt, Stream}; + use std::any::Any; use std::sync::Arc; use std::task::{Context, Poll}; - use super::StreamType; - use crate::physical_plan::aggregates::GroupByOrderMode::{ - FullyOrdered, PartiallyOrdered, - }; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::{ - DisplayAs, ExecutionPlan, Partitioning, RecordBatchStream, - SendableRecordBatchStream, Statistics, - }; - use crate::prelude::SessionContext; + use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) fn create_test_schema() -> Result { @@ -1370,6 +1373,57 @@ mod tests { ) } + /// Generates some mock data for aggregate tests. + fn some_data_v2() -> (Arc, Vec) { + // Define a schema: + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::Float64, false), + ])); + + // Generate data so that first and last value results are at 2nd and + // 3rd partitions. With this construction, we guarantee we don't receive + // the expected result by accident, but merging actually works properly; + // i.e. it doesn't depend on the data insertion order. 
+ ( + schema.clone(), + vec![ + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), + Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + schema, + vec![ + Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + Arc::new(Float64Array::from(vec![2.0, 3.0, 4.0, 5.0])), + ], + ) + .unwrap(), + ], + ) + } + async fn check_grouping_sets(input: Arc) -> Result<()> { let input_schema = input.schema(); @@ -1885,6 +1939,134 @@ mod tests { Ok(()) } + #[tokio::test] + async fn run_first_last_multi_partitions() -> Result<()> { + for use_coalesce_batches in [false, true] { + for is_first_acc in [false, true] { + first_last_multi_partitions(use_coalesce_batches, is_first_acc).await? + } + } + Ok(()) + } + + // This function either constructs the physical plan below, + // + // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", + // " CoalesceBatchesExec: target_batch_size=1024", + // " CoalescePartitionsExec", + // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", + // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", + // + // or + // + // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", + // " CoalescePartitionsExec", + // " AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[FIRST_VALUE(b)], ordering_mode=None", + // " MemoryExec: partitions=4, partition_sizes=[1, 1, 1, 1]", + // + // and checks whether the function `merge_batch` works correctly for + // FIRST_VALUE and LAST_VALUE functions. 
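The plan sketched in the comment above funnels partial aggregates from four partitions into one final aggregate, so correctness hinges on `merge_batch` honoring the ordering requirement rather than the order in which partial results arrive. As a rough illustration (hypothetical types, not the actual DataFusion accumulator implementation), an order-aware FIRST_VALUE partial state only needs to carry the ordering key alongside the value:

```rust
/// A stripped-down sketch of an order-aware "first value" partial state; this
/// is not the DataFusion `Accumulator` trait, just the merging idea the test
/// below exercises.
#[derive(Debug, Clone, Copy)]
struct FirstValueState {
    value: f64,
    /// Ordering key of the row `value` came from (column `b` in the test).
    order_key: f64,
}

impl FirstValueState {
    /// Merge another partition's partial result: keep whichever row is
    /// smallest with respect to the ordering requirement.
    fn merge(&mut self, other: &FirstValueState) {
        if other.order_key < self.order_key {
            *self = *other;
        }
    }
}

fn main() {
    // Per-partition partial results for the `a = 3` group of the mock data
    // above: b = 2.0, 1.0, 4.0 and 3.0 respectively.
    let partials = [
        FirstValueState { value: 2.0, order_key: 2.0 },
        FirstValueState { value: 1.0, order_key: 1.0 },
        FirstValueState { value: 4.0, order_key: 4.0 },
        FirstValueState { value: 3.0, order_key: 3.0 },
    ];
    // Final stage: fold the partials in whatever order they arrive.
    let mut acc = partials[0];
    for p in &partials[1..] {
        acc.merge(p);
    }
    // Matches the FIRST_VALUE(b) = 1.0 result asserted for `a = 3` below.
    assert_eq!(acc.value, 1.0);
}
```

A LAST_VALUE sketch would be identical with the comparison flipped, which is why the Final stage needs no input ordering of its own.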
+ async fn first_last_multi_partitions( + use_coalesce_batches: bool, + is_first_acc: bool, + ) -> Result<()> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let (schema, data) = some_data_v2(); + let partition1 = data[0].clone(); + let partition2 = data[1].clone(); + let partition3 = data[2].clone(); + let partition4 = data[3].clone(); + + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + + let ordering_req = vec![PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions::default(), + }]; + let aggregates: Vec> = if is_first_acc { + vec![Arc::new(FirstValue::new( + col("b", &schema)?, + "FIRST_VALUE(b)".to_string(), + DataType::Float64, + ordering_req.clone(), + vec![DataType::Float64], + ))] + } else { + vec![Arc::new(LastValue::new( + col("b", &schema)?, + "LAST_VALUE(b)".to_string(), + DataType::Float64, + ordering_req.clone(), + vec![DataType::Float64], + ))] + }; + + let memory_exec = Arc::new(MemoryExec::try_new( + &[ + vec![partition1], + vec![partition2], + vec![partition3], + vec![partition4], + ], + schema.clone(), + None, + )?); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + vec![None], + vec![Some(ordering_req.clone())], + memory_exec, + schema.clone(), + )?); + let coalesce = if use_coalesce_batches { + let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); + Arc::new(CoalesceBatchesExec::new(coalesce, 1024)) as Arc + } else { + Arc::new(CoalescePartitionsExec::new(aggregate_exec)) + as Arc + }; + let aggregate_final = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + groups, + aggregates.clone(), + vec![None], + vec![Some(ordering_req)], + coalesce, + schema, + )?) 
as Arc; + + let result = crate::physical_plan::collect(aggregate_final, task_ctx).await?; + if is_first_acc { + let expected = vec![ + "+---+----------------+", + "| a | FIRST_VALUE(b) |", + "+---+----------------+", + "| 2 | 0.0 |", + "| 3 | 1.0 |", + "| 4 | 3.0 |", + "+---+----------------+", + ]; + assert_batches_eq!(expected, &result); + } else { + let expected = vec![ + "+---+---------------+", + "| a | LAST_VALUE(b) |", + "+---+---------------+", + "| 2 | 3.0 |", + "| 3 | 5.0 |", + "| 4 | 6.0 |", + "+---+---------------+", + ]; + assert_batches_eq!(expected, &result); + }; + Ok(()) + } + #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; diff --git a/datafusion/core/src/physical_plan/aggregates/order/mod.rs b/datafusion/core/src/physical_plan/aggregates/order/mod.rs index ebe662c980bf..f0b49872b1c5 100644 --- a/datafusion/core/src/physical_plan/aggregates/order/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/order/mod.rs @@ -52,7 +52,6 @@ impl GroupOrdering { } = ordering; Ok(match mode { - GroupByOrderMode::None => GroupOrdering::None, GroupByOrderMode::PartiallyOrdered => { let partial = GroupOrderingPartial::try_new(input_schema, order_indices, ordering)?; diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index 19c05eaada32..46dbc9ef6204 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -23,7 +23,7 @@ use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_physical_expr::expressions::{BinaryExpr, Column}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -50,9 +50,7 @@ pub fn build_checked_file_list(dir: &str, ext: &str) -> Result> { let mut filenames: Vec = Vec::new(); build_file_list_recurse(dir, &mut filenames, ext)?; if filenames.is_empty() { - return Err(DataFusionError::Plan(format!( - "No files found at {dir} with file extension {ext}" - ))); + return plan_err!("No files found at {dir} with file extension {ext}"); } Ok(filenames) } @@ -86,7 +84,7 @@ fn build_file_list_recurse( filenames.push(path_name.to_string()); } } else { - return Err(DataFusionError::Plan("Invalid path".to_string())); + return plan_err!("Invalid path"); } } } @@ -99,24 +97,27 @@ pub(crate) fn spawn_buffered( mut input: SendableRecordBatchStream, buffer: usize, ) -> SendableRecordBatchStream { - // Use tokio only if running from a tokio context (#2201) - if tokio::runtime::Handle::try_current().is_err() { - return input; - }; - - let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); - - let sender = builder.tx(); + // Use tokio only if running from a multi-thread tokio context + match tokio::runtime::Handle::try_current() { + Ok(handle) + if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread => + { + let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer); + + let sender = builder.tx(); + + builder.spawn(async move { + while let Some(item) = input.next().await { + if sender.send(item).await.is_err() { + return; + } + } + }); - builder.spawn(async move { - while let Some(item) = input.next().await { - if 
sender.send(item).await.is_err() { - return; - } + builder.build() } - }); - - builder.build() + _ => input, + } } /// Computes the statistics for an in-memory RecordBatch diff --git a/datafusion/core/src/physical_plan/filter.rs b/datafusion/core/src/physical_plan/filter.rs index 29a705daf029..e8c181a34ba2 100644 --- a/datafusion/core/src/physical_plan/filter.rs +++ b/datafusion/core/src/physical_plan/filter.rs @@ -37,7 +37,7 @@ use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; @@ -74,9 +74,9 @@ impl FilterExec { input: input.clone(), metrics: ExecutionPlanMetricsSet::new(), }), - other => Err(DataFusionError::Plan(format!( - "Filter predicate must return boolean values, not {other:?}" - ))), + other => { + plan_err!("Filter predicate must return boolean values, not {other:?}") + } } } diff --git a/datafusion/core/src/physical_plan/insert.rs b/datafusion/core/src/physical_plan/insert.rs index 8766b62e9a9e..622e33b117fd 100644 --- a/datafusion/core/src/physical_plan/insert.rs +++ b/datafusion/core/src/physical_plan/insert.rs @@ -36,7 +36,6 @@ use std::fmt::Debug; use std::sync::Arc; use crate::physical_plan::stream::RecordBatchStreamAdapter; -use crate::physical_plan::Distribution; use datafusion_common::DataFusionError; use datafusion_execution::TaskContext; @@ -57,7 +56,7 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { /// or rollback required. async fn write_all( &self, - data: SendableRecordBatchStream, + data: Vec, context: &Arc, ) -> Result; } @@ -97,7 +96,7 @@ impl InsertExec { } } - fn make_input_stream( + fn execute_input_stream( &self, partition: usize, context: Arc, @@ -136,6 +135,18 @@ impl InsertExec { ))) } } + + fn execute_all_input_streams( + &self, + context: Arc, + ) -> Result> { + let n_input_parts = self.input.output_partitioning().partition_count(); + let mut streams = Vec::with_capacity(n_input_parts); + for part in 0..n_input_parts { + streams.push(self.execute_input_stream(part, context.clone())?); + } + Ok(streams) + } } impl DisplayAs for InsertExec { @@ -172,8 +183,12 @@ impl ExecutionPlan for InsertExec { None } - fn required_input_distribution(&self) -> Vec { - vec![Distribution::SinglePartition] + fn benefits_from_input_partitioning(&self) -> bool { + // Incoming number of partitions is taken to be the + // number of files the query is required to write out. + // The optimizer should not change this number. + // Parrallelism is handled within the appropriate DataSink + false } fn required_input_ordering(&self) -> Vec>> { @@ -218,20 +233,10 @@ impl ExecutionPlan for InsertExec { ) -> Result { if partition != 0 { return Err(DataFusionError::Internal( - format!("Invalid requested partition {partition}. InsertExec requires a single input partition." - ))); + "InsertExec can only be called on partition 0!".into(), + )); } - - // Execute each of our own input's partitions and pass them to the sink - let input_partition_count = self.input.output_partitioning().partition_count(); - if input_partition_count != 1 { - return Err(DataFusionError::Internal(format!( - "Invalid input partition count {input_partition_count}. \ - InsertExec needs only a single partition." 
- ))); - } - - let data = self.make_input_stream(0, context.clone())?; + let data = self.execute_all_input_streams(context.clone())?; let count_schema = self.count_schema.clone(); let sink = self.sink.clone(); diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs b/datafusion/core/src/physical_plan/joins/cross_join.rs index 1ecbfbea95ca..eaee9892d0e2 100644 --- a/datafusion/core/src/physical_plan/joins/cross_join.rs +++ b/datafusion/core/src/physical_plan/joins/cross_join.rs @@ -34,7 +34,7 @@ use crate::physical_plan::{ SendableRecordBatchStream, Statistics, }; use async_trait::async_trait; -use datafusion_common::DataFusionError; +use datafusion_common::{plan_err, DataFusionError}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; @@ -176,10 +176,9 @@ impl ExecutionPlan for CrossJoinExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] || children[1] { - Err(DataFusionError::Plan( + plan_err!( "Cross Join Error: Cross join is not supported for the unbounded inputs." - .to_string(), - )) + ) } else { Ok(false) } diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index ce1d6dbcc083..2108893ccbb0 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -25,10 +25,9 @@ use std::task::Poll; use std::{any::Any, usize, vec}; use crate::physical_plan::joins::utils::{ - add_offset_to_ordering_equivalence_classes, adjust_indices_by_join_type, - apply_join_filter_to_indices, build_batch_from_indices, - calculate_hash_join_output_order, get_final_indices_from_bit_map, - need_produce_result_in_final, JoinSide, + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + calculate_join_output_ordering, combine_join_ordering_equivalence_properties, + get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide, }; use crate::physical_plan::DisplayAs; use crate::physical_plan::{ @@ -77,7 +76,7 @@ use arrow::{ use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; use datafusion_common::cast::{as_dictionary_array, as_string_array}; -use datafusion_common::{DataFusionError, JoinType, Result}; +use datafusion_common::{plan_err, DataFusionError, JoinType, Result}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_physical_expr::OrderingEquivalenceProperties; @@ -139,9 +138,7 @@ impl HashJoinExec { let left_schema = left.schema(); let right_schema = right.schema(); if on.is_empty() { - return Err(DataFusionError::Plan( - "On constraints in HashJoinExec should be non-empty".to_string(), - )); + return plan_err!("On constraints in HashJoinExec should be non-empty"); } check_join_is_valid(&left_schema, &right_schema, &on)?; @@ -151,11 +148,14 @@ impl HashJoinExec { let random_state = RandomState::with_seeds(0, 0, 0, 0); - let output_order = calculate_hash_join_output_order( - join_type, - left.output_ordering(), - right.output_ordering(), - left.schema().fields().len(), + let output_order = calculate_join_output_ordering( + left.output_ordering().unwrap_or(&[]), + right.output_ordering().unwrap_or(&[]), + *join_type, + &on, + left_schema.fields.len(), + &Self::maintains_input_order(*join_type), + Some(Self::probe_side()), )?; Ok(HashJoinExec { @@ 
-209,6 +209,23 @@ impl HashJoinExec { pub fn null_equals_null(&self) -> bool { self.null_equals_null } + + /// Calculate order preservation flags for this hash join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner | JoinType::RightAnti | JoinType::RightSemi + ), + ] + } + + /// Get probe side information for the hash join. + pub fn probe_side() -> JoinSide { + // In current implementation right side is always probe side. + JoinSide::Right + } } impl DisplayAs for HashJoinExec { @@ -291,14 +308,14 @@ impl ExecutionPlan for HashJoinExec { )); if breaking { - Err(DataFusionError::Plan(format!( + plan_err!( "Join Error: The join with cannot be executed with unbounded inputs. {}", if left && right { "Currently, we do not support unbounded inputs on both sides." } else { "Please consider a different type of join or sources." } - ))) + ) } else { Ok(left || right) } @@ -355,13 +372,7 @@ impl ExecutionPlan for HashJoinExec { // are processed sequentially in the probe phase, and unmatched rows are directly output // as results, these results tend to retain the order of the probe side table. fn maintains_input_order(&self) -> Vec { - vec![ - false, - matches!( - self.join_type, - JoinType::Inner | JoinType::RightAnti | JoinType::RightSemi - ), - ] + Self::maintains_input_order(self.join_type) } fn equivalence_properties(&self) -> EquivalenceProperties { @@ -377,31 +388,16 @@ impl ExecutionPlan for HashJoinExec { } fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - let mut new_properties = OrderingEquivalenceProperties::new(self.schema()); - let left_columns_len = self.left.schema().fields.len(); - let right_oeq_properties = self.right.ordering_equivalence_properties(); - match self.join_type { - JoinType::RightAnti | JoinType::RightSemi => { - // For `RightAnti` and `RightSemi` joins, the right table schema remains valid. - // Hence, its ordering equivalence properties can be used as is. - new_properties.extend(right_oeq_properties.classes().iter().cloned()); - } - JoinType::Inner => { - // For `Inner` joins, the right table schema is no longer valid. - // Size of the left table is added as an offset to the right table - // columns when constructing the join output schema. - let updated_right_classes = add_offset_to_ordering_equivalence_classes( - right_oeq_properties.classes(), - left_columns_len, - ) - .unwrap(); - new_properties.extend(updated_right_classes); - } - // In other cases, we cannot propagate ordering equivalences as - // the output ordering is not preserved. 
- _ => {} - } - new_properties + combine_join_ordering_equivalence_properties( + &self.join_type, + &self.left, + &self.right, + self.schema(), + &self.maintains_input_order(), + Some(Self::probe_side()), + self.equivalence_properties(), + ) + .unwrap() } fn children(&self) -> Vec> { @@ -471,10 +467,10 @@ impl ExecutionPlan for HashJoinExec { )) } PartitionMode::Auto => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid HashJoinExec, unsupported PartitionMode {:?} in execute()", PartitionMode::Auto - ))); + ); } }; @@ -718,14 +714,14 @@ impl RecordBatchStream for HashJoinStream { // "+----+----+-----+----+----+-----+", // "| a1 | b1 | c1 | a2 | b2 | c2 |", // "+----+----+-----+----+----+-----+", +// "| 9 | 8 | 90 | 8 | 8 | 80 |", // "| 11 | 8 | 110 | 8 | 8 | 80 |", // "| 13 | 10 | 130 | 10 | 10 | 100 |", // "| 13 | 10 | 130 | 12 | 10 | 120 |", -// "| 9 | 8 | 90 | 8 | 8 | 80 |", // "+----+----+-----+----+----+-----+" // And the result of build and probe indices are: -// Build indices: 5, 6, 6, 4 -// Probe indices: 3, 4, 5, 3 +// Build indices: 4, 5, 6, 6 +// Probe indices: 3, 3, 4, 5 #[allow(clippy::too_many_arguments)] pub fn build_equal_condition_join_indices( build_hashmap: &JoinHashMap, @@ -756,8 +752,36 @@ pub fn build_equal_condition_join_indices( // Using a buffer builder to avoid slower normal builder let mut build_indices = UInt64BufferBuilder::new(0); let mut probe_indices = UInt32BufferBuilder::new(0); - // Visit all of the probe rows - for (row, hash_value) in hash_values.iter().enumerate() { + // The chained list algorithm generates build indices for each probe row in a reversed sequence as such: + // Build Indices: [5, 4, 3] + // Probe Indices: [1, 1, 1] + // + // This affects the output sequence. Hypothetically, it's possible to preserve the lexicographic order on the build side. + // Let's consider probe rows [0,1] as an example: + // + // When the probe iteration sequence is reversed, the following pairings can be derived: + // + // For probe row 1: + // (5, 1) + // (4, 1) + // (3, 1) + // + // For probe row 0: + // (5, 0) + // (4, 0) + // (3, 0) + // + // After reversing both sets of indices, we obtain reversed indices: + // + // (3,0) + // (4,0) + // (5,0) + // (3,1) + // (4,1) + // (5,1) + // + // With this approach, the lexicographic order on both the probe side and the build side is preserved. 
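A self-contained sketch of the reversal trick described in the comment above, using plain vectors and a hypothetical chain layout instead of `JoinHashMap` and the Arrow buffer builders:

```rust
use std::collections::HashMap;

fn main() {
    // Hypothetical chained hash map: for each probe key, the matching
    // build-side row indices come out in *descending* order (newer rows are
    // pushed to the front of the chain).
    let mut chains: HashMap<u64, Vec<u64>> = HashMap::new();
    chains.insert(1, vec![5, 4, 3]);
    chains.insert(2, vec![6]);

    // Probe-side rows and the key each of them hashes to.
    let probe_keys: Vec<u64> = vec![1, 2, 1];

    let mut build_indices: Vec<u64> = Vec::new();
    let mut probe_indices: Vec<u32> = Vec::new();

    // Visit probe rows in reverse; each row appends its (descending) chain.
    for (row, key) in probe_keys.iter().enumerate().rev() {
        if let Some(chain) = chains.get(key) {
            for &build_row in chain {
                build_indices.push(build_row);
                probe_indices.push(row as u32);
            }
        }
    }

    // Reversing both vectors restores ascending probe order, and ascending
    // build order within each probe row.
    build_indices.reverse();
    probe_indices.reverse();

    assert_eq!(probe_indices, vec![0, 0, 0, 1, 2, 2, 2]);
    assert_eq!(build_indices, vec![3, 4, 5, 6, 3, 4, 5]);
}
```

A single O(n) reversal at the end avoids rebuilding the chains in ascending order while still preserving the lexicographic order on both sides.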
+ for (row, hash_value) in hash_values.iter().enumerate().rev() { // Get the hash and find it in the build index // For every item on the build and probe we check if it matches @@ -781,6 +805,9 @@ pub fn build_equal_condition_join_indices( } } } + // Reversing both sets of indices + build_indices.as_slice_mut().reverse(); + probe_indices.as_slice_mut().reverse(); let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None); let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None); diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs index ed63fb844897..cad3b4743bc9 100644 --- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs +++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs @@ -73,6 +73,7 @@ type JoinLeftData = (RecordBatch, MemoryReservation); /// |--------------------------------|--------------------------------------------|-------------| /// | Inner/Left/LeftSemi/LeftAnti | (UnspecifiedDistribution, SinglePartition) | right | /// | Right/RightSemi/RightAnti/Full | (SinglePartition, UnspecifiedDistribution) | left | +/// | Full | (SinglePartition, SinglePartition) | left | /// #[derive(Debug)] pub struct NestedLoopJoinExec { @@ -119,12 +120,12 @@ impl NestedLoopJoinExec { }) } - /// left (build) side which gets hashed + /// left side pub fn left(&self) -> &Arc { &self.left } - /// right (probe) side which are filtered by the hash table + /// right side pub fn right(&self) -> &Arc { &self.right } diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 9b8e9e85cd4f..b3721eb4d616 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -33,9 +33,9 @@ use std::task::{Context, Poll}; use crate::physical_plan::expressions::Column; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::joins::utils::{ - add_offset_to_lex_ordering, add_offset_to_ordering_equivalence_classes, - build_join_schema, check_join_is_valid, combine_join_equivalence_properties, - estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, + build_join_schema, calculate_join_output_ordering, check_join_is_valid, + combine_join_equivalence_properties, combine_join_ordering_equivalence_properties, + estimate_join_statistics, partitioned_join_output_partitioning, JoinOn, JoinSide, }; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use crate::physical_plan::{ @@ -49,14 +49,12 @@ use arrow::compute::{concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, JoinType, Result}; +use datafusion_common::{plan_err, DataFusionError, JoinType, Result}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::utils::normalize_sort_exprs; use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; use futures::{Stream, StreamExt}; -use itertools::Itertools; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. 
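The SortMergeJoin hunks below derive the probe (streaming) side from the join type instead of hard-coding the streamed/buffered split inside `execute()`. A minimal standalone sketch of that selection (plain enums, not the DataFusion types) might look like this:

```rust
/// Simplified model: right-anchored join types probe the right input,
/// everything else (including Full) probes the left input, mirroring the
/// probe-side choice introduced below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JoinType {
    Inner,
    Left,
    Right,
    Full,
    LeftSemi,
    LeftAnti,
    RightSemi,
    RightAnti,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Side {
    Left,
    Right,
}

fn probe_side(join_type: JoinType) -> Side {
    match join_type {
        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => Side::Right,
        _ => Side::Left,
    }
}

fn main() {
    let left = "left input";
    let right = "right input";

    for jt in [JoinType::Inner, JoinType::RightSemi] {
        // The probe side is streamed; the other side is buffered.
        let (streamed, buffered) = if probe_side(jt) == Side::Left {
            (left, right)
        } else {
            (right, left)
        };
        println!("{jt:?}: streamed = {streamed}, buffered = {buffered}");
    }
}
```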
@@ -86,38 +84,6 @@ pub struct SortMergeJoinExec { pub(crate) null_equals_null: bool, } -/// Replaces the right column (first index in the `on_column` tuple) with -/// the left column (zeroth index in the tuple) inside `right_ordering`. -fn replace_on_columns_of_right_ordering( - on_columns: &[(Column, Column)], - right_ordering: &mut [PhysicalSortExpr], - left_columns_len: usize, -) { - for (left_col, right_col) in on_columns { - let right_col = - Column::new(right_col.name(), right_col.index() + left_columns_len); - for item in right_ordering.iter_mut() { - if let Some(col) = item.expr.as_any().downcast_ref::() { - if right_col.eq(col) { - item.expr = Arc::new(left_col.clone()) as _; - } - } - } - } -} - -/// Merge left and right sort expressions, checking for duplicates. -fn merge_vectors( - left: &[PhysicalSortExpr], - right: &[PhysicalSortExpr], -) -> Vec { - left.iter() - .cloned() - .chain(right.iter().cloned()) - .unique() - .collect() -} - impl SortMergeJoinExec { /// Tries to create a new [SortMergeJoinExec]. /// The inputs are sorted using `sort_options` are applied to the columns in the `on` @@ -142,11 +108,11 @@ impl SortMergeJoinExec { check_join_is_valid(&left_schema, &right_schema, &on)?; if sort_options.len() != on.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Expected number of sort options: {}, actual: {}", on.len(), sort_options.len() - ))); + ); } let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on @@ -165,43 +131,15 @@ impl SortMergeJoinExec { }) .unzip(); - let output_ordering = match join_type { - JoinType::Inner => { - match (left.output_ordering(), right.output_ordering()) { - // If both sides have orderings, ordering of the right hand side - // can be appended to the left side ordering for inner joins. - (Some(left_ordering), Some(right_ordering)) => { - let left_columns_len = left.schema().fields.len(); - let mut right_ordering = - add_offset_to_lex_ordering(right_ordering, left_columns_len)?; - replace_on_columns_of_right_ordering( - &on, - &mut right_ordering, - left_columns_len, - ); - Some(merge_vectors(left_ordering, &right_ordering)) - } - (Some(left_ordering), _) => Some(left_ordering.to_vec()), - _ => None, - } - } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - left.output_ordering().map(|sort_exprs| sort_exprs.to_vec()) - } - JoinType::RightSemi | JoinType::RightAnti => right - .output_ordering() - .map(|sort_exprs| sort_exprs.to_vec()), - JoinType::Right => { - let left_columns_len = left.schema().fields.len(); - right - .output_ordering() - .map(|sort_exprs| { - add_offset_to_lex_ordering(sort_exprs, left_columns_len) - }) - .map_or(Ok(None), |v| v.map(Some))? - } - JoinType::Full => None, - }; + let output_ordering = calculate_join_output_ordering( + left.output_ordering().unwrap_or(&[]), + right.output_ordering().unwrap_or(&[]), + join_type, + &on, + left_schema.fields.len(), + &Self::maintains_input_order(join_type), + Some(Self::probe_side(&join_type)), + )?; let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); @@ -221,6 +159,35 @@ impl SortMergeJoinExec { }) } + /// Get probe side (e.g streaming side) information for this sort merge join. + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + // When output schema contains only the right side, probe side is right. + // Otherwise probe side is the left side. 
+ match join_type { + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinSide::Right + } + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::LeftAnti + | JoinType::LeftSemi => JoinSide::Left, + } + } + + /// Calculate order preservation flags for this sort merge join. + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + JoinType::Inner => vec![true, false], + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + vec![false, true] + } + _ => vec![false, false], + } + } + /// Set of common columns used to join on pub fn on(&self) -> &[(Column, Column)] { &self.on @@ -299,14 +266,7 @@ impl ExecutionPlan for SortMergeJoinExec { } fn maintains_input_order(&self) -> Vec { - match self.join_type { - JoinType::Inner => vec![true, true], - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - vec![false, true] - } - _ => vec![false, false], - } + Self::maintains_input_order(self.join_type) } fn equivalence_properties(&self) -> EquivalenceProperties { @@ -322,62 +282,16 @@ impl ExecutionPlan for SortMergeJoinExec { } fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { - let mut new_properties = OrderingEquivalenceProperties::new(self.schema()); - let left_columns_len = self.left.schema().fields.len(); - let left_oeq_properties = self.left.ordering_equivalence_properties(); - let right_oeq_properties = self.right.ordering_equivalence_properties(); - match self.join_type { - JoinType::Inner => { - // Since left side is the stream side for this `SortMergeJoin` implementation, - // global ordering of the left table is preserved at the output. Hence, left - // side ordering equivalences are still valid. - new_properties.extend(left_oeq_properties.classes().iter().cloned()); - if let Some(output_ordering) = &self.output_ordering { - // Update right table ordering equivalence expression indices; i.e. - // add left table size as an offset. - let updated_right_oeq_classes = - add_offset_to_ordering_equivalence_classes( - right_oeq_properties.classes(), - left_columns_len, - ) - .unwrap(); - let left_output_ordering = self.left.output_ordering().unwrap_or(&[]); - // Right side ordering equivalence properties should be prepended with - // those of the left side while constructing output ordering equivalence - // properties for `SortMergeJoin`. As an example; - // - // If the right table ordering equivalences contain `b ASC`, and the output - // ordering of the left table is `a ASC`, then the ordering equivalence `b ASC` - // for the right table should be converted to `a ASC, b ASC` before it is added - // to the ordering equivalences of `SortMergeJoinExec`. - for oeq_class in updated_right_oeq_classes { - for ordering in oeq_class.others() { - // Entries inside ordering equivalence should be normalized before insertion. 
- let normalized_ordering = normalize_sort_exprs( - ordering, - self.equivalence_properties().classes(), - &[], - ); - let new_oeq_ordering = - merge_vectors(left_output_ordering, &normalized_ordering); - new_properties.add_equal_conditions(( - output_ordering, - &new_oeq_ordering, - )); - } - } - } - } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - new_properties.extend(left_oeq_properties.classes().iter().cloned()); - } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - new_properties.extend(right_oeq_properties.classes().iter().cloned()); - } - // All ordering equivalences from left and/or right sides are invalidated. - _ => {} - } - new_properties + combine_join_ordering_equivalence_properties( + &self.join_type, + &self.left, + &self.right, + self.schema(), + &self.maintains_input_order(), + Some(Self::probe_side(&self.join_type)), + self.equivalence_properties(), + ) + .unwrap() } fn children(&self) -> Vec> { @@ -416,25 +330,13 @@ impl ExecutionPlan for SortMergeJoinExec { consider using RepartitionExec", ))); } - - let (streamed, buffered, on_streamed, on_buffered) = match self.join_type { - JoinType::Inner - | JoinType::Left - | JoinType::Full - | JoinType::LeftAnti - | JoinType::LeftSemi => ( - self.left.clone(), - self.right.clone(), - self.on.iter().map(|on| on.0.clone()).collect(), - self.on.iter().map(|on| on.1.clone()).collect(), - ), - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => ( - self.right.clone(), - self.left.clone(), - self.on.iter().map(|on| on.1.clone()).collect(), - self.on.iter().map(|on| on.0.clone()).collect(), - ), - }; + let (on_left, on_right) = self.on.iter().cloned().unzip(); + let (streamed, buffered, on_streamed, on_buffered) = + if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { + (self.left.clone(), self.right.clone(), on_left, on_right) + } else { + (self.right.clone(), self.left.clone(), on_right, on_left) + }; // execute children plans let streamed = streamed.execute(partition, context.clone())?; diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 1818e4b91c1b..dc8bcc2edb26 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -75,7 +75,7 @@ use crate::physical_plan::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; use datafusion_common::utils::bisect; -use datafusion_common::JoinType; +use datafusion_common::{plan_err, JoinType}; use datafusion_common::{DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -278,9 +278,9 @@ impl SymmetricHashJoinExec { // Error out if no "on" contraints are given: if on.is_empty() { - return Err(DataFusionError::Plan( - "On constraints in SymmetricHashJoinExec should be non-empty".to_string(), - )); + return plan_err!( + "On constraints in SymmetricHashJoinExec should be non-empty" + ); } // Check if the join is valid with the given on constraints: diff --git a/datafusion/core/src/physical_plan/joins/utils.rs b/datafusion/core/src/physical_plan/joins/utils.rs index ea9c11a3ed98..abba191f047b 100644 --- a/datafusion/core/src/physical_plan/joins/utils.rs +++ b/datafusion/core/src/physical_plan/joins/utils.rs @@ -40,15 +40,20 @@ use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{Transformed, 
TreeNode}; -use datafusion_common::{DataFusionError, JoinType, Result, ScalarValue, SharedResult}; +use datafusion_common::{ + plan_err, DataFusionError, JoinType, Result, ScalarValue, SharedResult, +}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ - EquivalentClass, LexOrdering, LexOrderingRef, OrderingEquivalentClass, PhysicalExpr, - PhysicalSortExpr, + EquivalentClass, LexOrdering, LexOrderingRef, OrderingEquivalenceProperties, + OrderingEquivalentClass, PhysicalExpr, PhysicalSortExpr, }; +use datafusion_physical_expr::utils::normalize_sort_exprs; + use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; +use itertools::Itertools; use parking_lot::Mutex; /// The on clause of the join, as vector of (left, right) columns. @@ -89,9 +94,9 @@ fn check_join_set_is_valid( let right_missing = on_right.difference(right).collect::>(); if !left_missing.is_empty() | !right_missing.is_empty() { - return Err(DataFusionError::Plan(format!( - "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}", - ))); + return plan_err!( + "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {left_missing:?}\nMissing on the right: {right_missing:?}" + ); }; Ok(()) @@ -147,64 +152,93 @@ pub fn adjust_right_output_partitioning( } } -fn adjust_right_order( - right_order: &[PhysicalSortExpr], +/// Replaces the right column (first index in the `on_column` tuple) with +/// the left column (zeroth index in the tuple) inside `right_ordering`. +fn replace_on_columns_of_right_ordering( + on_columns: &[(Column, Column)], + right_ordering: &mut [PhysicalSortExpr], left_columns_len: usize, -) -> Result> { - right_order - .iter() - .map(|sort_expr| { - let expr = sort_expr.expr.clone(); - let adjusted = expr.transform_up(&|expr| { - Ok( - if let Some(column) = expr.as_any().downcast_ref::() { - let new_col = - Column::new(column.name(), column.index() + left_columns_len); - Transformed::Yes(Arc::new(new_col)) - } else { - Transformed::No(expr) - }, - ) - })?; - Ok(PhysicalSortExpr { - expr: adjusted, - options: sort_expr.options, - }) - }) - .collect::>>() -} - -/// Calculate the output order for hash join. -pub fn calculate_hash_join_output_order( - join_type: &JoinType, - maybe_left_order: Option<&[PhysicalSortExpr]>, - maybe_right_order: Option<&[PhysicalSortExpr]>, - left_len: usize, -) -> Result>> { - match maybe_right_order { - Some(right_order) => { - let result = match join_type { - JoinType::Inner => { - // We modify the indices of the right order columns because their - // columns are appended to the right side of the left schema. - let mut adjusted_right_order = - adjust_right_order(right_order, left_len)?; - if let Some(left_order) = maybe_left_order { - adjusted_right_order.extend_from_slice(left_order); - } - Some(adjusted_right_order) +) { + for (left_col, right_col) in on_columns { + let right_col = + Column::new(right_col.name(), right_col.index() + left_columns_len); + for item in right_ordering.iter_mut() { + if let Some(col) = item.expr.as_any().downcast_ref::() { + if right_col.eq(col) { + item.expr = Arc::new(left_col.clone()) as _; } - JoinType::RightAnti | JoinType::RightSemi => Some(right_order.to_vec()), - _ => None, - }; - - Ok(result) + } } - None => Ok(None), } } -/// Combine the Equivalence Properties for Join Node +/// Calculate the output ordering of a given join operation. 
+pub fn calculate_join_output_ordering( + left_ordering: LexOrderingRef, + right_ordering: LexOrderingRef, + join_type: JoinType, + on_columns: &[(Column, Column)], + left_columns_len: usize, + maintains_input_order: &[bool], + probe_side: Option, +) -> Result> { + // All joins have 2 children: + assert_eq!(maintains_input_order.len(), 2); + let left_maintains = maintains_input_order[0]; + let right_maintains = maintains_input_order[1]; + let (mut right_ordering, on_columns) = match join_type { + // In the case below, right ordering should be offseted with the left + // side length, since we append the right table to the left table. + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + let updated_on_columns = on_columns + .iter() + .map(|(left, right)| { + ( + left.clone(), + Column::new(right.name(), right.index() + left_columns_len), + ) + }) + .collect::>(); + let updated_right_ordering = + add_offset_to_lex_ordering(right_ordering, left_columns_len)?; + (updated_right_ordering, updated_on_columns) + } + _ => (right_ordering.to_vec(), on_columns.to_vec()), + }; + let output_ordering = match (left_maintains, right_maintains) { + (true, true) => { + return Err(DataFusionError::Execution( + "Cannot maintain ordering of both sides".to_string(), + )) + } + (true, false) => { + // Special case, we can prefix ordering of right side with the ordering of left side. + if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { + replace_on_columns_of_right_ordering( + &on_columns, + &mut right_ordering, + left_columns_len, + ); + merge_vectors(left_ordering, &right_ordering) + } else { + left_ordering.to_vec() + } + } + (false, true) => { + // Special case, we can prefix ordering of left side with the ordering of right side. + if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { + merge_vectors(&right_ordering, left_ordering) + } else { + right_ordering + } + } + // Doesn't maintain ordering, output ordering is None. + (false, false) => return Ok(None), + }; + Ok((!output_ordering.is_empty()).then_some(output_ordering)) +} + +/// Combine equivalence properties of the given join inputs. pub fn combine_join_equivalence_properties( join_type: JoinType, left_properties: EquivalenceProperties, @@ -256,7 +290,7 @@ pub fn combine_join_equivalence_properties( new_properties } -/// Calculate the Equivalence Properties for CrossJoin Node +/// Calculate equivalence properties for the given cross join operation. pub fn cross_join_equivalence_properties( left_properties: EquivalenceProperties, right_properties: EquivalenceProperties, @@ -283,6 +317,155 @@ pub fn cross_join_equivalence_properties( new_properties } +/// Update right table ordering equivalences so that they point to valid indices +/// at the output of the join schema. To do so, we increment column indices by left table size +/// when join schema consist of combination of left and right schema (Inner, Left, Full, Right joins). +fn get_updated_right_ordering_equivalence_properties( + join_type: &JoinType, + right_oeq_classes: &[OrderingEquivalentClass], + left_columns_len: usize, +) -> Result> { + match join_type { + // In these modes, indices of the right schema should be offset by + // the left table size. + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + add_offset_to_ordering_equivalence_classes( + right_oeq_classes, + left_columns_len, + ) + } + _ => Ok(right_oeq_classes.to_vec()), + } +} + +/// Merge left and right sort expressions, checking for duplicates. 
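For the join types handled above, the output schema is the left schema followed by the right schema, so any sort expression referring to a right-side column must have its column index shifted by the number of left columns before it can describe the join output. A minimal, self-contained sketch of that re-indexing (the `Col` type here is an illustrative stand-in, not DataFusion's `Column`):

```rust
/// Illustrative stand-in for a physical column reference: a name plus an
/// index into the schema of the plan that produces it.
#[derive(Clone, Debug, PartialEq)]
struct Col {
    name: String,
    index: usize,
}

/// Shift right-side column indices so they point into the combined
/// (left ++ right) join output schema.
fn add_offset(right_cols: &[Col], left_columns_len: usize) -> Vec<Col> {
    right_cols
        .iter()
        .map(|c| Col {
            name: c.name.clone(),
            index: c.index + left_columns_len,
        })
        .collect()
}

fn main() {
    // Right table schema: [b, c]; the left table has 3 columns.
    let right = vec![
        Col { name: "b".into(), index: 0 },
        Col { name: "c".into(), index: 1 },
    ];
    // In the join output [l0, l1, l2, b, c], `b` sits at index 3 and `c` at 4.
    assert_eq!(
        add_offset(&right, 3),
        vec![
            Col { name: "b".into(), index: 3 },
            Col { name: "c".into(), index: 4 },
        ]
    );
}
```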
+fn merge_vectors( + left: &[PhysicalSortExpr], + right: &[PhysicalSortExpr], +) -> Vec { + left.iter() + .cloned() + .chain(right.iter().cloned()) + .unique() + .collect() +} + +/// Prefix with existing ordering. +fn prefix_ordering_equivalence_with_existing_ordering( + existing_ordering: &[PhysicalSortExpr], + oeq_classes: &[OrderingEquivalentClass], + eq_classes: &[EquivalentClass], +) -> Vec { + oeq_classes + .iter() + .map(|oeq_class| { + let normalized_head = normalize_sort_exprs(oeq_class.head(), eq_classes, &[]); + let updated_head = merge_vectors(existing_ordering, &normalized_head); + let updated_others = oeq_class + .others() + .iter() + .map(|ordering| { + let normalized_ordering = + normalize_sort_exprs(ordering, eq_classes, &[]); + merge_vectors(existing_ordering, &normalized_ordering) + }) + .collect(); + OrderingEquivalentClass::new(updated_head, updated_others) + }) + .collect() +} + +/// Calculate ordering equivalence properties for the given join operation. +pub fn combine_join_ordering_equivalence_properties( + join_type: &JoinType, + left: &Arc, + right: &Arc, + schema: SchemaRef, + maintains_input_order: &[bool], + probe_side: Option, + join_eq_properties: EquivalenceProperties, +) -> Result { + let mut new_properties = OrderingEquivalenceProperties::new(schema); + let left_columns_len = left.schema().fields.len(); + let left_oeq_properties = left.ordering_equivalence_properties(); + let right_oeq_properties = right.ordering_equivalence_properties(); + // All joins have 2 children + assert_eq!(maintains_input_order.len(), 2); + let left_maintains = maintains_input_order[0]; + let right_maintains = maintains_input_order[1]; + match (left_maintains, right_maintains) { + (true, true) => { + return Err(DataFusionError::Plan( + "Cannot maintain ordering of both sides".to_string(), + )) + } + (true, false) => { + new_properties.extend(left_oeq_properties.classes().iter().cloned()); + // In this special case, right side ordering can be prefixed with left side ordering. + if probe_side == Some(JoinSide::Left) + && right.output_ordering().is_some() + && *join_type == JoinType::Inner + { + let right_oeq_classes = + get_updated_right_ordering_equivalence_properties( + join_type, + right_oeq_properties.classes(), + left_columns_len, + )?; + let left_output_ordering = left.output_ordering().unwrap_or(&[]); + // Right side ordering equivalence properties should be prepended with + // those of the left side while constructing output ordering equivalence + // properties since stream side is the left side. + // + // If the right table ordering equivalences contain `b ASC`, and the output + // ordering of the left table is `a ASC`, then the ordering equivalence `b ASC` + // for the right table should be converted to `a ASC, b ASC` before it is added + // to the ordering equivalences of the join. + let updated_right_oeq_classes = + prefix_ordering_equivalence_with_existing_ordering( + left_output_ordering, + &right_oeq_classes, + join_eq_properties.classes(), + ); + new_properties.extend(updated_right_oeq_classes); + } + } + (false, true) => { + let right_oeq_classes = get_updated_right_ordering_equivalence_properties( + join_type, + right_oeq_properties.classes(), + left_columns_len, + )?; + new_properties.extend(right_oeq_classes); + // In this special case, left side ordering can be prefixed with right side ordering. 
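As the comments around these helpers explain, an ordering equivalence coming from the non-stream side is only valid once it is prefixed with the stream side's output ordering: `b ASC` from the right table becomes `a ASC, b ASC` when the left output ordering is `a ASC`. A dependency-free sketch of that prefix-and-deduplicate step, with plain strings standing in for `PhysicalSortExpr` (the real `merge_vectors` gets the same first-occurrence-wins behaviour from `Itertools::unique`):

```rust
/// Concatenate the stream-side output ordering with a normalized ordering
/// from the other side, dropping duplicates while keeping the first
/// occurrence of each sort expression.
fn prefix_with_existing(existing: &[&str], other: &[&str]) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    for expr in existing.iter().chain(other.iter()) {
        if !out.iter().any(|e| e.as_str() == *expr) {
            out.push((*expr).to_string());
        }
    }
    out
}

fn main() {
    // Left (stream side) output ordering: `a ASC`.
    // Right-side ordering equivalence:    `b ASC`.
    // For an inner join probing the left side, `b ASC` is only usable as a
    // suffix of the stream-side ordering, i.e. `a ASC, b ASC`.
    assert_eq!(
        prefix_with_existing(&["a ASC"], &["b ASC"]),
        vec!["a ASC", "b ASC"]
    );
    // Duplicates are collapsed: `a ASC` followed by `a ASC, b ASC`
    // still yields `a ASC, b ASC`.
    assert_eq!(
        prefix_with_existing(&["a ASC"], &["a ASC", "b ASC"]),
        vec!["a ASC", "b ASC"]
    );
}
```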
+ if probe_side == Some(JoinSide::Right) + && left.output_ordering().is_some() + && *join_type == JoinType::Inner + { + let left_oeq_classes = right_oeq_properties.classes(); + let right_output_ordering = right.output_ordering().unwrap_or(&[]); + // Left side ordering equivalence properties should be prepended with + // those of the right side while constructing output ordering equivalence + // properties since stream side is the right side. + // + // If the right table ordering equivalences contain `b ASC`, and the output + // ordering of the left table is `a ASC`, then the ordering equivalence `b ASC` + // for the right table should be converted to `a ASC, b ASC` before it is added + // to the ordering equivalences of the join. + let updated_left_oeq_classes = + prefix_ordering_equivalence_with_existing_ordering( + right_output_ordering, + left_oeq_classes, + join_eq_properties.classes(), + ); + new_properties.extend(updated_left_oeq_classes); + } + } + (false, false) => {} + } + Ok(new_properties) +} + /// Adds the `offset` value to `Column` indices inside `expr`. This function is /// generally used during the update of the right table schema in join operations. pub(crate) fn add_offset_to_expr( diff --git a/datafusion/core/src/physical_plan/memory.rs b/datafusion/core/src/physical_plan/memory.rs index fbf7f46dd50d..5e7917b978e0 100644 --- a/datafusion/core/src/physical_plan/memory.rs +++ b/datafusion/core/src/physical_plan/memory.rs @@ -68,13 +68,23 @@ impl DisplayAs for MemoryExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let partitions: Vec<_> = + let partition_sizes: Vec<_> = self.partitions.iter().map(|b| b.len()).collect(); + + let output_ordering = self + .sort_information + .as_ref() + .map(|output_ordering| { + let order_strings: Vec<_> = + output_ordering.iter().map(|e| e.to_string()).collect(); + format!(", output_ordering={}", order_strings.join(",")) + }) + .unwrap_or_else(|| "".to_string()); + write!( f, - "MemoryExec: partitions={}, partition_sizes={:?}", - partitions.len(), - partitions + "MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}", + partition_sizes.len(), ) } } diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs index 0d322bfb11e2..66254ee6f5f8 100644 --- a/datafusion/core/src/physical_plan/mod.rs +++ b/datafusion/core/src/physical_plan/mod.rs @@ -22,6 +22,7 @@ use self::metrics::MetricsSet; use self::{ coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, }; +use crate::datasource::physical_plan::FileScanConfig; use crate::physical_plan::expressions::PhysicalSortExpr; use datafusion_common::Result; pub use datafusion_common::{ColumnStatistics, Statistics}; @@ -54,7 +55,7 @@ pub trait RecordBatchStream: Stream> { fn schema(&self) -> SchemaRef; } -/// Trait for a [`Stream`] of [`RecordBatch`]es +/// Trait for a [`Stream`](futures::stream::Stream) of [`RecordBatch`]es pub type SendableRecordBatchStream = Pin>; /// EmptyRecordBatchStream can be used to create a RecordBatchStream @@ -226,6 +227,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// Returns the global output statistics for this `ExecutionPlan` node. fn statistics(&self) -> Statistics; + + /// Returns the [`FileScanConfig`] in case this is a data source scanning execution plan or `None` otherwise. 
+ fn file_scan_config(&self) -> Option<&FileScanConfig> { + None + } } /// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index f660f0acf89a..c7ae09bb2e34 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -35,7 +35,7 @@ use arrow::compute::{concat_batches, lexsort_to_indices, take}; use arrow::datatypes::SchemaRef; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryReservation, }; @@ -76,27 +76,147 @@ impl ExternalSorterMetrics { } } -/// Sort arbitrary size of data to get a total order (may spill several times during sorting based on free memory available). +/// Sorts an arbitrary sized, unsorted, stream of [`RecordBatch`]es to +/// a total order. Depending on the input size and memory manager +/// configuration, writes intermediate results to disk ("spills") +/// using Arrow IPC format. +/// +/// # Algorithm /// -/// The basic architecture of the algorithm: /// 1. get a non-empty new batch from input -/// 2. check with the memory manager if we could buffer the batch in memory -/// 2.1 if memory sufficient, then buffer batch in memory, go to 1. -/// 2.2 if the memory threshold is reached, sort all buffered batches and spill to file. -/// buffer the batch in memory, go to 1. -/// 3. when input is exhausted, merge all in memory batches and spills to get a total order. +/// +/// 2. check with the memory manager there is sufficient space to +/// buffer the batch in memory 2.1 if memory sufficient, buffer +/// batch in memory, go to 1. +/// +/// 2.2 if no more memory is available, sort all buffered batches and +/// spill to file. buffer the next batch in memory, go to 1. +/// +/// 3. when input is exhausted, merge all in memory batches and spills +/// to get a total order. +/// +/// # When data fits in available memory +/// +/// If there is sufficient memory, data is sorted in memory to produce the output +/// +/// ```text +/// ┌─────┐ +/// │ 2 │ +/// │ 3 │ +/// │ 1 │─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// │ 4 │ +/// │ 2 │ │ +/// └─────┘ ▼ +/// ┌─────┐ +/// │ 1 │ In memory +/// │ 4 │─ ─ ─ ─ ─ ─▶ sort/merge ─ ─ ─ ─ ─▶ total sorted output +/// │ 1 │ +/// └─────┘ ▲ +/// ... │ +/// +/// ┌─────┐ │ +/// │ 4 │ +/// │ 3 │─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// └─────┘ +/// +/// in_mem_batches +/// +/// ``` +/// +/// # When data does not fit in available memory +/// +/// When memory is exhausted, data is first sorted and written to one +/// or more spill files on disk: +/// +/// ```text +/// ┌─────┐ .─────────────────. +/// │ 2 │ ( ) +/// │ 3 │ │`─────────────────'│ +/// │ 1 │─ ─ ─ ─ ─ ─ ─ │ ┌────┐ │ +/// │ 4 │ │ │ │ 1 │░ │ +/// │ 2 │ │ │... │░ │ +/// └─────┘ ▼ │ │ 4 │░ ┌ ─ ─ │ +/// ┌─────┐ │ └────┘░ 1 │░ │ +/// │ 1 │ In memory │ ░░░░░░ │ ░░ │ +/// │ 4 │─ ─ ▶ sort/merge ─ ─ ─ ─ ┼ ─ ─ ─ ─ ─▶ ... │░ │ +/// │ 1 │ and write to file │ │ ░░ │ +/// └─────┘ │ 4 │░ │ +/// ... 
▲ │ └░─░─░░ │ +/// │ │ ░░░░░░ │ +/// ┌─────┐ │.─────────────────.│ +/// │ 4 │ │ ( ) +/// │ 3 │─ ─ ─ ─ ─ ─ ─ `─────────────────' +/// └─────┘ +/// +/// in_mem_batches spills +/// (file on disk in Arrow +/// IPC format) +/// ``` +/// +/// Once the input is completely read, the spill files are read and +/// merged with any in memory batches to produce a single total sorted +/// output: +/// +/// ```text +/// .─────────────────. +/// ( ) +/// │`─────────────────'│ +/// │ ┌────┐ │ +/// │ │ 1 │░ │ +/// │ │... │─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ +/// │ │ 4 │░ ┌────┐ │ │ +/// │ └────┘░ │ 1 │░ │ ▼ +/// │ ░░░░░░ │ │░ │ +/// │ │... │─ ─│─ ─ ─ ▶ merge ─ ─ ─▶ total sorted output +/// │ │ │░ │ +/// │ │ 4 │░ │ ▲ +/// │ └────┘░ │ │ +/// │ ░░░░░░ │ +/// │.─────────────────.│ │ +/// ( ) +/// `─────────────────' │ +/// spills +/// │ +/// +/// │ +/// +/// ┌─────┐ │ +/// │ 1 │ +/// │ 4 │─ ─ ─ ─ │ +/// └─────┘ │ +/// ... In memory +/// └ ─ ─ ─▶ sort/merge +/// ┌─────┐ +/// │ 4 │ ▲ +/// │ 3 │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// └─────┘ +/// +/// in_mem_batches +/// ``` struct ExternalSorter { + /// schema of the output (and the input) schema: SchemaRef, + /// Potentially unsorted in memory buffer in_mem_batches: Vec, + /// if `Self::in_mem_batches` are sorted in_mem_batches_sorted: bool, + /// If data has previously been spilled, the locations of the + /// spill files (in Arrow IPC format) spills: Vec, /// Sort expressions expr: Arc<[PhysicalSortExpr]>, + /// Runtime metrics metrics: ExternalSorterMetrics, + /// If Some, the maximum number of output rows that will be + /// produced. fetch: Option, + /// Memory usage tracking reservation: MemoryReservation, + /// The partition id that this Sort is handling (for identification) partition_id: usize, + /// A handle to the runtime to get Disk spill files runtime: Arc, + /// The target number of rows for output batches batch_size: usize, } @@ -142,7 +262,7 @@ impl ExternalSorter { if self.reservation.try_grow(size).is_err() { let before = self.reservation.size(); self.in_mem_sort().await?; - // Sorting may have freed memory, especially if fetch is not `None` + // Sorting may have freed memory, especially if fetch is `Some` // // As such we check again, and if the memory usage has dropped by // a factor of 2, and we can allocate the necessary capacity, @@ -168,7 +288,15 @@ impl ExternalSorter { !self.spills.is_empty() } - /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`. + /// Returns the final sorted output of all batches inserted via + /// [`Self::insert_batch`] as a stream of [`RecordBatch`]es. + /// + /// This process could either be: + /// + /// 1. An in-memory sort/merge (if the input fit in memory) + /// + /// 2. A combined streaming merge incorporating both in-memory + /// batches and data from spill files on disk. fn sort(&mut self) -> Result { if self.spilled_before() { let mut streams = vec![]; @@ -201,18 +329,25 @@ impl ExternalSorter { } } + /// How much memory is buffered in this `ExternalSorter`? fn used(&self) -> usize { self.reservation.size() } + /// How many bytes have been spilled to disk? fn spilled_bytes(&self) -> usize { self.metrics.spilled_bytes.value() } + /// How many spill files have been created? fn spill_count(&self) -> usize { self.metrics.spill_count.value() } + /// Writes any `in_memory_batches` to a spill file and clears + /// the batches. The contents of the spil file are sorted. + /// + /// Returns the amount of memory freed. 
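The buffering and spilling behaviour documented above reduces to a simple loop: keep incoming batches in memory while they fit the budget, otherwise sort what is buffered and write it out as a spill, then merge the in-memory data with all spilled runs at the end. A deliberately simplified model of that loop, with a byte counter instead of a `MemoryReservation`, integer vectors instead of `RecordBatch`es, and an in-memory `Vec` instead of spill files:

```rust
/// Tiny model of the buffer-or-spill decision; not the actual ExternalSorter.
struct MiniSorter {
    budget: usize,
    used: usize,
    in_mem: Vec<Vec<i32>>,
    spills: Vec<Vec<i32>>, // stands in for sorted runs written to disk
}

impl MiniSorter {
    fn insert_batch(&mut self, batch: Vec<i32>) {
        let size = batch.len() * std::mem::size_of::<i32>();
        if self.used + size > self.budget && !self.in_mem.is_empty() {
            self.spill(); // free memory by flushing a sorted run
        }
        self.used += size;
        self.in_mem.push(batch);
    }

    /// Sort the buffered data, move it to the "spill" area, reset the counter.
    fn spill(&mut self) {
        let mut run: Vec<i32> = self.in_mem.drain(..).flatten().collect();
        run.sort();
        self.spills.push(run);
        self.used = 0;
    }

    /// Final merge of the in-memory data and all spilled runs.
    fn finish(mut self) -> Vec<i32> {
        self.spill();
        let mut all: Vec<i32> = self.spills.into_iter().flatten().collect();
        all.sort(); // a real implementation streams a k-way merge instead
        all
    }
}

fn main() {
    let mut sorter = MiniSorter { budget: 16, used: 0, in_mem: vec![], spills: vec![] };
    for batch in [vec![3, 1], vec![4, 1], vec![5, 9, 2, 6]] {
        sorter.insert_batch(batch);
    }
    assert_eq!(sorter.finish(), vec![1, 1, 2, 3, 4, 5, 6, 9]);
}
```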
async fn spill(&mut self) -> Result { // we could always get a chance to free some memory as long as we are holding some if self.in_mem_batches.is_empty() { @@ -255,7 +390,64 @@ impl ExternalSorter { Ok(()) } - /// Consumes in_mem_batches returning a sorted stream + /// Consumes in_mem_batches returning a sorted stream of + /// batches. This proceeds in one of two ways: + /// + /// # Small Datasets + /// + /// For "smaller" datasets, the data is first concatenated into a + /// single batch and then sorted. This is often faster than + /// sorting and then merging. + /// + /// ```text + /// ┌─────┐ + /// │ 2 │ + /// │ 3 │ + /// │ 1 │─ ─ ─ ─ ┐ ┌─────┐ + /// │ 4 │ │ 2 │ + /// │ 2 │ │ │ 3 │ + /// └─────┘ │ 1 │ sorted output + /// ┌─────┐ ▼ │ 4 │ stream + /// │ 1 │ │ 2 │ + /// │ 4 │─ ─▶ concat ─ ─ ─ ─ ▶│ 1 │─ ─ ▶ sort ─ ─ ─ ─ ─▶ + /// │ 1 │ │ 4 │ + /// └─────┘ ▲ │ 1 │ + /// ... │ │ ... │ + /// │ 4 │ + /// ┌─────┐ │ │ 3 │ + /// │ 4 │ └─────┘ + /// │ 3 │─ ─ ─ ─ ┘ + /// └─────┘ + /// in_mem_batches + /// ``` + /// + /// # Larger datasets + /// + /// For larger datasets, the batches are first sorted individually + /// and then merged together. + /// + /// ```text + /// ┌─────┐ ┌─────┐ + /// │ 2 │ │ 1 │ + /// │ 3 │ │ 2 │ + /// │ 1 │─ ─▶ sort ─ ─▶│ 2 │─ ─ ─ ─ ─ ┐ + /// │ 4 │ │ 3 │ + /// │ 2 │ │ 4 │ │ + /// └─────┘ └─────┘ sorted output + /// ┌─────┐ ┌─────┐ ▼ stream + /// │ 1 │ │ 1 │ + /// │ 4 │─ ▶ sort ─ ─ ▶│ 1 ├ ─ ─ ▶ merge ─ ─ ─ ─▶ + /// │ 1 │ │ 4 │ + /// └─────┘ └─────┘ ▲ + /// ... ... ... │ + /// + /// ┌─────┐ ┌─────┐ │ + /// │ 4 │ │ 3 │ + /// │ 3 │─ ▶ sort ─ ─ ▶│ 4 │─ ─ ─ ─ ─ ┘ + /// └─────┘ └─────┘ + /// + /// in_mem_batches + /// ``` fn in_mem_sort_stream( &mut self, metrics: BaselineMetrics, @@ -296,6 +488,7 @@ impl ExternalSorter { ) } + /// Sorts a single `RecordBatch` into a single stream fn sort_batch_stream( &self, batch: RecordBatch, @@ -417,8 +610,8 @@ fn read_spill(sender: Sender>, path: &Path) -> Result<()> { /// Sort execution plan. /// -/// This operator supports sorting datasets that are larger than the -/// memory allotted by the memory manager, by spilling to disk. +/// Support sorting datasets that are larger than the memory allotted +/// by the memory manager, by spilling to disk. #[derive(Debug)] pub struct SortExec { /// Input schema @@ -556,9 +749,7 @@ impl ExecutionPlan for SortExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - Err(DataFusionError::Plan( - "Sort Error: Can not sort unbounded inputs.".to_string(), - )) + plan_err!("Sort Error: Can not sort unbounded inputs.") } else { Ok(false) } diff --git a/datafusion/core/src/physical_plan/streaming.rs b/datafusion/core/src/physical_plan/streaming.rs index 58f4596a6ed5..28623f299541 100644 --- a/datafusion/core/src/physical_plan/streaming.rs +++ b/datafusion/core/src/physical_plan/streaming.rs @@ -24,7 +24,7 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use futures::stream::StreamExt; -use datafusion_common::{DataFusionError, Result, Statistics}; +use datafusion_common::{plan_err, DataFusionError, Result, Statistics}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use log::debug; @@ -69,9 +69,7 @@ impl StreamingTableExec { "target schema does not contain partition schema. \ Target_schema: {schema:?}. 
Partiton Schema: {partition_schema:?}" ); - return Err(DataFusionError::Plan( - "Mismatch between schema and batches".to_string(), - )); + return plan_err!("Mismatch between schema and batches"); } } diff --git a/datafusion/core/src/physical_plan/values.rs b/datafusion/core/src/physical_plan/values.rs index 70e00ed0340e..757e020350f3 100644 --- a/datafusion/core/src/physical_plan/values.rs +++ b/datafusion/core/src/physical_plan/values.rs @@ -26,7 +26,7 @@ use crate::physical_plan::{ use arrow::array::new_null_array; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::ScalarValue; +use datafusion_common::{plan_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_execution::TaskContext; use std::any::Any; @@ -48,7 +48,7 @@ impl ValuesExec { data: Vec>>, ) -> Result { if data.is_empty() { - return Err(DataFusionError::Plan("Values list cannot be empty".into())); + return plan_err!("Values list cannot be empty"); } let n_row = data.len(); let n_col = schema.fields().len(); @@ -72,9 +72,9 @@ impl ValuesExec { ScalarValue::try_from_array(&a, 0) } Ok(ColumnarValue::Array(a)) => { - Err(DataFusionError::Plan(format!( + plan_err!( "Cannot have array values {a:?} in a values list" - ))) + ) } Err(err) => Err(err), } diff --git a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs index 1f0da4a8e6ab..4889f667f3b3 100644 --- a/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/bounded_window_agg_exec.rs @@ -31,7 +31,7 @@ use crate::physical_plan::{ ColumnStatistics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use datafusion_common::Result; +use datafusion_common::{plan_err, Result}; use datafusion_execution::TaskContext; use ahash::RandomState; @@ -611,9 +611,9 @@ impl LinearSearch { .iter() .map(|item| match item.evaluate(record_batch)? { ColumnarValue::Array(array) => Ok(array), - ColumnarValue::Scalar(scalar) => Err(DataFusionError::Plan(format!( - "Sort operation is not applicable to scalar value {scalar}" - ))), + ColumnarValue::Scalar(scalar) => { + plan_err!("Sort operation is not applicable to scalar value {scalar}") + } }) .collect() } diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 7e7fc22965d7..83979af2f43d 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -39,8 +39,8 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::utils::{evaluate_partition_ranges, get_at_indices}; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::{plan_err, DataFusionError}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{OrderingEquivalenceProperties, PhysicalSortRequirement}; use futures::stream::Stream; @@ -176,10 +176,9 @@ impl ExecutionPlan for WindowAggExec { /// infinite, returns an error to indicate this. fn unbounded_output(&self, children: &[bool]) -> Result { if children[0] { - Err(DataFusionError::Plan( + plan_err!( "Window Error: Windowing is not currently support for unbounded inputs." 
- .to_string(), - )) + ) } else { Ok(false) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d0019b4fa37a..6b868b9b2424 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -62,10 +62,11 @@ use crate::{ use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use async_trait::async_trait; -use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_common::{plan_err, DFSchema, ScalarValue}; use datafusion_expr::expr::{ self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, WindowFunction, + GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF, TryCast, + WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -181,9 +182,22 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let expr = create_physical_name(expr, false)?; Ok(format!("{expr} IS NOT UNKNOWN")) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { let expr = create_physical_name(expr, false)?; - Ok(format!("{expr}[{key}]")) + let name = match field { + GetFieldAccess::NamedStructField { name } => format!("{expr}[{name}]"), + GetFieldAccess::ListIndex { key } => { + let key = create_physical_name(key, false)?; + format!("{expr}[{key}]") + } + GetFieldAccess::ListRange { start, stop } => { + let start = create_physical_name(start, false)?; + let stop = create_physical_name(stop, false)?; + format!("{expr}[{start}:{stop}]") + } + }; + + Ok(name) } Expr::ScalarFunction(func) => { create_function_physical_name(&func.fun.to_string(), false, &func.args) @@ -532,7 +546,7 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Dml(DmlStatement { table_name, - op: WriteOp::Insert, + op: WriteOp::InsertInto, input, .. }) => { @@ -540,7 +554,24 @@ impl DefaultPhysicalPlanner { let schema = session_state.schema_for_ref(table_name)?; if let Some(provider) = schema.table(name).await { let input_exec = self.create_initial_plan(input, session_state).await?; - provider.insert_into(session_state, input_exec).await + provider.insert_into(session_state, input_exec, false).await + } else { + return Err(DataFusionError::Execution(format!( + "Table '{table_name}' does not exist" + ))); + } + } + LogicalPlan::Dml(DmlStatement { + table_name, + op: WriteOp::InsertOverwrite, + input, + .. + }) => { + let name = table_name.table(); + let schema = session_state.schema_for_ref(table_name)?; + if let Some(provider) = schema.table(name).await { + let input_exec = self.create_initial_plan(input, session_state).await?; + provider.insert_into(session_state, input_exec, true).await } else { return Err(DataFusionError::Execution(format!( "Table '{table_name}' does not exist" @@ -707,7 +738,7 @@ impl DefaultPhysicalPlanner { groups.clone(), aggregates.clone(), filters.clone(), - order_bys.clone(), + order_bys, input_exec, physical_input_schema.clone(), )?); @@ -719,6 +750,14 @@ impl DefaultPhysicalPlanner { && session_state.config().target_partitions() > 1 && session_state.config().repartition_aggregations(); + // Some aggregators may be modified during initialization for + // optimization purposes. For example, a FIRST_VALUE may turn + // into a LAST_VALUE with the reverse ordering requirement. 
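Earlier in this physical_planner hunk, `create_physical_name` gains a match over the three `GetFieldAccess` variants and renders them as `expr[name]`, `expr[key]`, and `expr[start:stop]`. A stripped-down illustration of the same dispatch, using a local enum rather than the real `datafusion_expr` types:

```rust
/// Local stand-in for the three flavours of indexed-field access.
enum FieldAccess {
    NamedStructField { name: String },
    ListIndex { key: String },
    ListRange { start: String, stop: String },
}

/// Mirror of the display-name logic: struct fields and single list indices
/// render as `expr[x]`, ranges render as `expr[start:stop]`.
fn physical_name(expr: &str, field: &FieldAccess) -> String {
    match field {
        FieldAccess::NamedStructField { name } => format!("{expr}[{name}]"),
        FieldAccess::ListIndex { key } => format!("{expr}[{key}]"),
        FieldAccess::ListRange { start, stop } => format!("{expr}[{start}:{stop}]"),
    }
}

fn main() {
    assert_eq!(
        physical_name("col", &FieldAccess::ListRange { start: "2".into(), stop: "4".into() }),
        "col[2:4]"
    );
    assert_eq!(
        physical_name("s", &FieldAccess::NamedStructField { name: "a".into() }),
        "s[a]"
    );
}
```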
+ // To reflect such changes to subsequent stages, use the updated + // `AggregateExpr`/`PhysicalSortExpr` objects. + let updated_aggregates = initial_aggr.aggr_expr.clone(); + let updated_order_bys = initial_aggr.order_by_expr.clone(); + let (initial_aggr, next_partition_mode): ( Arc, AggregateMode, @@ -742,9 +781,9 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(AggregateExec::try_new( next_partition_mode, final_grouping_set, - aggregates, + updated_aggregates, filters, - order_bys, + updated_order_bys, initial_aggr, physical_input_schema.clone(), )?)) @@ -1198,19 +1237,20 @@ impl DefaultPhysicalPlanner { ).await?; } - let plan = maybe_plan.ok_or_else(|| DataFusionError::Plan(format!( - "No installed planner was able to convert the custom node to an execution plan: {:?}", e.node - )))?; + let plan = match maybe_plan { + Some(v) => Ok(v), + _ => plan_err!("No installed planner was able to convert the custom node to an execution plan: {:?}", e.node) + }?; // Ensure the ExecutionPlan's schema matches the // declared logical schema to catch and warn about // logic errors when creating user defined plans. if !e.node.schema().matches_arrow_schema(&plan.schema()) { - Err(DataFusionError::Plan(format!( + plan_err!( "Extension planner for {:?} created an ExecutionPlan with mismatched schema. \ LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}", e.node, e.node.schema(), plan.schema() - ))) + ) } else { Ok(plan) } @@ -1547,10 +1587,10 @@ pub fn create_window_expr_with_name( }) .collect::>>()?; if !is_window_valid(window_frame) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound - ))); + ); } let window_frame = Arc::new(window_frame.clone()); @@ -1564,9 +1604,7 @@ pub fn create_window_expr_with_name( physical_input_schema, ) } - other => Err(DataFusionError::Plan(format!( - "Invalid window expression '{other:?}'" - ))), + other => plan_err!("Invalid window expression '{other:?}'"), } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 12ccc925990e..bfdb2bda1bb5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -42,9 +42,9 @@ use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::Expr::Wildcard; use datafusion_expr::{ - avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, scalar_subquery, - sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + array_agg, avg, col, count, exists, expr, in_subquery, lit, max, out_ref_col, + scalar_subquery, sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunction, }; use datafusion_physical_expr::var_provider::{VarProvider, VarType}; @@ -1521,3 +1521,23 @@ async fn use_var_provider() -> Result<()> { dataframe.collect().await?; Ok(()) } + +#[tokio::test] +async fn test_array_agg() -> Result<()> { + let df = create_test_table("test") + .await? 
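The planner comment above notes that a `FIRST_VALUE` may be rewritten into a `LAST_VALUE` with the reverse ordering requirement during the first aggregation stage, which is why the updated expressions must be passed on to the final stage. The rewrite is sound because the two formulations agree; a tiny standalone check of that equivalence (plain Rust, not DataFusion's accumulators):

```rust
fn main() {
    // (t, x) pairs
    let mut rows = vec![(3, "c"), (1, "a"), (2, "b")];

    // FIRST_VALUE(x ORDER BY t ASC): sort ascending by t, take the first x.
    rows.sort_by_key(|(t, _)| *t);
    let first_asc = rows.first().map(|(_, x)| *x);

    // LAST_VALUE(x ORDER BY t DESC): sort descending by t, take the last x.
    rows.sort_by_key(|(t, _)| std::cmp::Reverse(*t));
    let last_desc = rows.last().map(|(_, x)| *x);

    // The two formulations pick the same value, which is why the planner may
    // swap one for the other and must then propagate the updated expressions.
    assert_eq!(first_asc, last_desc);
    assert_eq!(first_asc, Some("a"));
}
```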
+ .aggregate(vec![], vec![array_agg(col("a"))])?; + + let results = df.collect().await?; + + let expected = vec![ + "+-------------------------------------+", + "| ARRAY_AGG(test.a) |", + "+-------------------------------------+", + "| [abcDEF, abc123, CBAdef, 123AbcDef] |", + "+-------------------------------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 5b18d616b3f9..a7cff6cbd758 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -19,12 +19,13 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::streaming::PartitionStream; use futures::StreamExt; use std::sync::Arc; use datafusion::datasource::streaming::StreamingTable; -use datafusion::datasource::MemTable; +use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; @@ -46,19 +47,20 @@ fn init() { #[tokio::test] async fn oom_sort() { - run_limit_test( + TestCase::new( "select * from t order by host DESC", vec![ "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", ], 200_000, ) - .await + .run() + .await } #[tokio::test] async fn group_by_none() { - run_limit_test( + TestCase::new( "select median(image) from t", vec![ "Resources exhausted: Failed to allocate additional", @@ -66,12 +68,13 @@ async fn group_by_none() { ], 20_000, ) + .run() .await } #[tokio::test] async fn group_by_row_hash() { - run_limit_test( + TestCase::new( "select count(*) from t GROUP BY response_bytes", vec![ "Resources exhausted: Failed to allocate additional", @@ -79,12 +82,13 @@ async fn group_by_row_hash() { ], 2_000, ) + .run() .await } #[tokio::test] async fn group_by_hash() { - run_limit_test( + TestCase::new( // group by dict column "select count(*) from t GROUP BY service, host, pod, container", vec![ @@ -93,42 +97,45 @@ async fn group_by_hash() { ], 1_000, ) + .run() .await } #[tokio::test] async fn join_by_key_multiple_partitions() { let config = SessionConfig::new().with_target_partitions(2); - run_limit_test_with_config( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput[0]", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] async fn join_by_key_single_partition() { let config = SessionConfig::new().with_target_partitions(1); - run_limit_test_with_config( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] async fn join_by_expression() { - run_limit_test( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", vec![ "Resources exhausted: Failed to allocate additional", @@ -136,12 +143,13 @@ async fn join_by_expression() { ], 1_000, ) + .run() .await } #[tokio::test] async fn cross_join() { - run_limit_test( + TestCase::new( "select t1.* from t t1 CROSS JOIN t t2", vec![ "Resources exhausted: Failed to allocate additional", @@ -149,6 +157,7 @@ async fn cross_join() { ], 1_000, ) + .run() .await } @@ -159,94 +168,185 @@ async fn merge_join() { 
.with_target_partitions(2) .set_bool("datafusion.optimizer.prefer_hash_join", false); - run_limit_test_with_config( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", vec![ "Resources exhausted: Failed to allocate additional", "SMJStream", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] -async fn test_limit_symmetric_hash_join() { - let config = SessionConfig::new(); - - run_streaming_test_with_config( +async fn symmetric_hash_join() { + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", vec![ "Resources exhausted: Failed to allocate additional", "SymmetricHashJoinStream", ], 1_000, - config, ) + .with_scenario(Scenario::AccessLogStreaming) + .run() .await } -/// 50 byte memory limit -const MEMORY_FRACTION: f64 = 0.95; - -/// runs the specified query against 1000 rows with specified -/// memory limit and no disk manager enabled with default SessionConfig. -async fn run_limit_test( - query: &str, - expected_error_contains: Vec<&str>, +/// Run the query with the specified memory limit, +/// and verifies the expected errors are returned +#[derive(Clone, Debug)] +struct TestCase { + query: String, + expected_errors: Vec, memory_limit: usize, -) { - let config = SessionConfig::new(); - run_limit_test_with_config(query, expected_error_contains, memory_limit, config).await + config: SessionConfig, + scenario: Scenario, } -/// runs the specified query against 1000 rows with a 50 -/// byte memory limit and no disk manager enabled -/// with specified SessionConfig instance -async fn run_limit_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, - config: SessionConfig, -) { - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); +impl TestCase { + fn new<'a>( + query: impl Into, + expected_errors: impl IntoIterator, + memory_limit: usize, + ) -> Self { + let expected_errors: Vec = + expected_errors.into_iter().map(|s| s.to_string()).collect(); + + Self { + query: query.into(), + expected_errors, + memory_limit, + config: SessionConfig::new(), + scenario: Scenario::AccessLog, + } + } - let table = MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); + /// Specify the configuration to use + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.config = config; + self + } - let rt_config = RuntimeConfig::new() - // do not allow spilling - .with_disk_manager(DiskManagerConfig::Disabled) - .with_memory_limit(memory_limit, MEMORY_FRACTION); + /// Specify the scenario to run + pub fn with_scenario(mut self, scenario: Scenario) -> Self { + self.scenario = scenario; + self + } + + /// Run the test, panic'ing on error + async fn run(self) { + let Self { + query, + expected_errors, + memory_limit, + config, + scenario, + } = self; + + let table = scenario.table(); - let runtime = RuntimeEnv::new(rt_config).unwrap(); + let rt_config = RuntimeConfig::new() + // do not allow spilling + .with_disk_manager(DiskManagerConfig::Disabled) + .with_memory_limit(memory_limit, MEMORY_FRACTION); - // Disabling physical optimizer rules to avoid sorts / repartitions - // (since RepartitionExec / SortExec also has a memory budget which we'll likely hit first) - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![]); + let runtime = RuntimeEnv::new(rt_config).unwrap(); - let ctx = SessionContext::with_state(state); - ctx.register_table("t", 
Arc::new(table)) - .expect("registering table"); + // Configure execution + let state = SessionState::with_config_rt(config, Arc::new(runtime)) + .with_physical_optimizer_rules(scenario.rules()); - let df = ctx.sql(query).await.expect("Planning query"); + let ctx = SessionContext::with_state(state); + ctx.register_table("t", table).expect("registering table"); - match df.collect().await { - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") + let df = ctx.sql(&query).await.expect("Planning query"); + + match df.collect().await { + Ok(_batches) => { + panic!("Unexpected success when running, expected memory limit failure") + } + Err(e) => { + for error_substring in expected_errors { + assert_contains!(e.to_string(), error_substring); + } + } } - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); + } +} + +/// 50 byte memory limit +const MEMORY_FRACTION: f64 = 0.95; + +/// Different data scenarios +#[derive(Clone, Debug)] +enum Scenario { + /// 1000 rows of access log data with batches of 50 rows + AccessLog, + + /// 1000 rows of access log data with batches of 50 rows in a + /// [`StreamingTable`] + AccessLogStreaming, +} + +impl Scenario { + /// return a TableProvider with data for the test + fn table(&self) -> Arc { + match self { + Self::AccessLog => { + let batches = access_log_batches(); + let table = + MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); + Arc::new(table) + } + Self::AccessLogStreaming => { + let batches = access_log_batches(); + + // Create a new streaming table with the generated schema and batches + let table = StreamingTable::try_new( + batches[0].schema(), + vec![Arc::new(DummyStreamPartition { + schema: batches[0].schema(), + batches: batches.clone(), + })], + ) + .unwrap() + .with_infinite_table(true); + Arc::new(table) + } + } + } + + /// return the optimizer rules to use + fn rules(&self) -> Vec> { + match self { + Self::AccessLog => { + // Disabling physical optimizer rules to avoid sorts / + // repartitions (since RepartitionExec / SortExec also + // has a memory budget which we'll likely hit first) + vec![] + } + Self::AccessLogStreaming => { + // Disable all physical optimizer rules except the + // JoinSelection rule to avoid sorts or repartition, + // as they also have memory budgets that may be hit + // first + vec![Arc::new(JoinSelection::new())] } } } } +fn access_log_batches() -> Vec { + AccessLogGenerator::new() + .with_row_limit(1000) + .with_max_batch_size(50) + .collect() +} + struct DummyStreamPartition { schema: SchemaRef, batches: Vec, @@ -266,66 +366,3 @@ impl PartitionStream for DummyStreamPartition { )) } } - -async fn run_streaming_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, - config: SessionConfig, -) { - // Generate a set of access logs with a row limit of 1000 and a max batch size of 50 - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); - - // Create a new streaming table with the generated schema and batches - let table = StreamingTable::try_new( - batches[0].schema(), - vec![Arc::new(DummyStreamPartition { - schema: batches[0].schema(), - batches: batches.clone(), - })], - ) - .unwrap() - .with_infinite_table(true); - - // Configure the runtime environment with custom settings - let rt_config = RuntimeConfig::new() - // Disable disk manager to disallow spilling - 
.with_disk_manager(DiskManagerConfig::Disabled) - // Set memory limit to 50 bytes - .with_memory_limit(memory_limit, MEMORY_FRACTION); - - // Create a new runtime environment with the configured settings - let runtime = RuntimeEnv::new(rt_config).unwrap(); - - // Create a new session state with the given configuration and runtime environment - // Disable all physical optimizer rules except the PipelineFixer rule to avoid sorts or - // repartition, as they also have memory budgets that may be hit first - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![Arc::new(JoinSelection::new())]); - - // Create a new session context with the session state - let ctx = SessionContext::with_state(state); - // Register the streaming table with the session context - ctx.register_table("t", Arc::new(table)) - .expect("registering table"); - - // Execute the SQL query and get a DataFrame - let df = ctx.sql(query).await.expect("Planning query"); - - // Collect the results of the DataFrame execution - match df.collect().await { - // If the execution succeeds, panic as we expect memory limit failure - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") - } - // If the execution fails, verify if the error contains the expected substrings - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); - } - } - } -} diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 5444b3a88f05..fc0a4e7c7ed2 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -60,6 +60,9 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("atan2(NULL, NULL)", "NULL"); test_expression!("atan2(1, NULL)", "NULL"); test_expression!("atan2(NULL, 1)", "NULL"); + test_expression!("nanvl(NULL, NULL)", "NULL"); + test_expression!("nanvl(1, NULL)", "NULL"); + test_expression!("nanvl(NULL, 1)", "NULL"); Ok(()) } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 85a806428548..c1adcf9d0a96 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -40,6 +40,7 @@ use datafusion::{ }; use datafusion::{execution::context::SessionContext, physical_plan::displayable}; use datafusion_common::cast::as_float64_array; +use datafusion_common::plan_err; use datafusion_common::{assert_contains, assert_not_contains}; use datafusion_expr::Volatility; use object_store::path::Path; @@ -405,10 +406,7 @@ async fn register_tpch_csv_data( DataType::Decimal128(_, _) => { cols.push(Box::new(Decimal128Builder::with_capacity(records.len()))) } - _ => { - let msg = format!("Not implemented: {}", field.data_type()); - Err(DataFusionError::Plan(msg))? 
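This hunk, like many others in the patch, swaps `Err(DataFusionError::Plan(format!(...)))` for the `plan_err!` macro from `datafusion_common`. The macro's definition is not part of this diff; judging from the call sites it must expand to an `Err(DataFusionError::Plan(..))` value that can be `return`ed directly or propagated with `?`. A hypothetical reconstruction along those lines, with a local error type rather than the real one:

```rust
#[derive(Debug)]
enum DataFusionError {
    Plan(String),
}

/// Hypothetical sketch of a `plan_err!`-style helper: format the arguments
/// and wrap them in `Err(DataFusionError::Plan(..))`, so call sites can write
/// `return plan_err!("...")` or `plan_err!("...")?`.
macro_rules! plan_err {
    ($($args:tt)*) => {
        Err(DataFusionError::Plan(format!($($args)*)))
    };
}

fn check(n: usize) -> Result<(), DataFusionError> {
    if n == 0 {
        return plan_err!("Values list cannot be empty");
    }
    Ok(())
}

fn main() {
    assert!(matches!(check(0), Err(DataFusionError::Plan(_))));
    assert!(check(3).is_ok());
}
```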
- } + _ => plan_err!("Not implemented: {}", field.data_type())?, } } @@ -446,10 +444,7 @@ async fn register_tpch_csv_data( let value_i128 = val.parse::().unwrap(); sb.append_value(value_i128); } - _ => Err(DataFusionError::Plan(format!( - "Not implemented: {}", - field.data_type() - )))?, + _ => plan_err!("Not implemented: {}", field.data_type())?, } } } diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs b/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs deleted file mode 100644 index 424a297f24ec..000000000000 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/util.rs +++ /dev/null @@ -1,55 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::DataType; -use datafusion_common::config::ConfigOptions; -use datafusion_common::TableReference; -use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; -use datafusion_sql::planner::ContextProvider; -use std::sync::Arc; - -pub struct LogicTestContextProvider {} - -// Only a mock, don't need to implement -impl ContextProvider for LogicTestContextProvider { - fn get_table_provider( - &self, - _name: TableReference, - ) -> datafusion_common::Result> { - todo!() - } - - fn get_function_meta(&self, _name: &str) -> Option> { - todo!() - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - todo!() - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - todo!() - } - - fn options(&self) -> &ConfigOptions { - todo!() - } - - fn get_window_meta(&self, _name: &str) -> Option> { - todo!() - } -} diff --git a/datafusion/core/tests/sqllogictests/src/main.rs b/datafusion/core/tests/sqllogictests/src/main.rs index c74d1cb11a47..7d23971fbfca 100644 --- a/datafusion/core/tests/sqllogictests/src/main.rs +++ b/datafusion/core/tests/sqllogictests/src/main.rs @@ -20,6 +20,7 @@ use std::path::{Path, PathBuf}; #[cfg(target_family = "windows")] use std::thread; +use datafusion_sqllogictest::{DataFusion, Postgres}; use futures::stream::StreamExt; use log::info; use sqllogictest::strict_column_validator; @@ -28,10 +29,6 @@ use tempfile::TempDir; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{DataFusionError, Result}; -use crate::engines::datafusion::DataFusion; -use crate::engines::postgres::Postgres; - -mod engines; mod setup; const TEST_DIRECTORY: &str = "tests/sqllogictests/test_files/"; diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 2f6f44c56be6..1780ccab7675 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -2290,3 +2290,272 @@ true false true NULL + + + +# +# 
regr_*() tests +# + +# regr_*() invalid input +statement error +select regr_slope(); + +statement error +select regr_intercept(*); + +statement error +select regr_count(*) from aggregate_test_100; + +statement error +select regr_r2(1); + +statement error +select regr_avgx(1,2,3); + +statement error +select regr_avgy(1, 'foo'); + +statement error +select regr_sxx('foo', 1); + +statement error +select regr_syy('foo', 'bar'); + +statement error +select regr_sxy(NULL, 'bar'); + + + +# regr_*() NULL results +query RRRRRRRRR +select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1); +---- +NULL NULL 1 NULL 1 1 0 0 0 + +query RRRRRRRRR +select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query RRRRRRRRR +select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6)); +---- +NULL NULL 3 NULL 1 4 0 8 0 + + + +# regr_*() basic tests +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,2), (2,4), (3,6)); +---- +2 0 3 1 2 4 2 8 4 + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + + + +# regr_*() functions ignore NULLs +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (2,4), (3,6)); +---- +2 0 2 1 2.5 5 0.5 2 1 + +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (NULL,4), (3,6)); +---- +NULL NULL 1 NULL 3 6 0 0 0 + +query RRRRRRRRR +select + regr_slope(column2, column1), + regr_intercept(column2, column1), + 
regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,NULL), (NULL,4), (NULL,NULL)); +---- +NULL NULL 0 NULL NULL NULL NULL NULL NULL + +query TRRRRRRRRR rowsort +select + column3, + regr_slope(column2, column1), + regr_intercept(column2, column1), + regr_count(column2, column1), + regr_r2(column2, column1), + regr_avgx(column2, column1), + regr_avgy(column2, column1), + regr_sxx(column2, column1), + regr_syy(column2, column1), + regr_sxy(column2, column1) +from (values (1,2,'a'), (2,4,'a'), (1,3,'b'), (3,9,'b'), (1,10,'c'), (NULL,100,'c')) +group by column3; +---- +a 2 0 2 1 1.5 3 0.5 2 1 +b 3 0 2 1 2 6 2 18 6 +c NULL NULL 1 NULL 1 10 0 0 0 + + + +# regr_*() testing merge_batch() from RegrAccumulator's internal implementation +statement ok +set datafusion.execution.batch_size = 1; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 2; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 3; + +query RRRRRRRRR +select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) +from aggregate_test_100; +---- +0.051534002628 0.48427355347 100 0.001929150558 0.479274948239 0.508972509913 6.707779292571 9.234223721582 0.345678715695 + +statement ok +set datafusion.execution.batch_size = 8192; + + + +# regr_*() testing retract_batch() from RegrAccumulator's internal implementation +query RRRRRRRRR +SELECT + regr_slope(column2, column1) OVER w AS slope, + regr_intercept(column2, column1) OVER w AS intercept, + regr_count(column2, column1) OVER w AS count, + regr_r2(column2, column1) OVER w AS r2, + regr_avgx(column2, column1) OVER w AS avgx, + regr_avgy(column2, column1) OVER w AS avgy, + regr_sxx(column2, column1) OVER w AS sxx, + regr_syy(column2, column1) OVER w AS syy, + regr_sxy(column2, column1) OVER w AS sxy +FROM (VALUES (1,2), (2,4), (3,6), (4,12), (5,15), (6,18)) AS t(column1, column2) +WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW); +---- +NULL NULL 1 NULL 1 2 0 0 0 +2 0 2 1 1.5 3 0.5 2 1 +2 0 3 1 2 4 2 8 4 +4 -4.666666666667 3 0.923076923077 3 7.333333333333 2 34.666666666667 8 +4.5 -7 3 0.964285714286 4 11 2 42 9 +3 0 3 1 5 15 2 18 6 + +query RRRRRRRRR +SELECT + regr_slope(column2, column1) OVER w AS slope, + regr_intercept(column2, column1) OVER w AS intercept, + regr_count(column2, column1) OVER w AS count, + regr_r2(column2, column1) OVER w AS r2, + regr_avgx(column2, column1) OVER w AS avgx, + regr_avgy(column2, column1) OVER w AS avgy, + regr_sxx(column2, column1) 
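The expected numbers in these `regr_*()` tests follow from the standard simple-linear-regression sums. As a quick cross-check of the basic `(1,2), (2,4), (3,6)` case above, here is a small standalone computation, independent of DataFusion:

```rust
fn main() {
    let pts: &[(f64, f64)] = &[(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)];
    let n = pts.len() as f64;

    let avgx = pts.iter().map(|(x, _)| x).sum::<f64>() / n;
    let avgy = pts.iter().map(|(_, y)| y).sum::<f64>() / n;
    // Centered sums of squares and cross products.
    let sxx: f64 = pts.iter().map(|(x, _)| (x - avgx).powi(2)).sum();
    let syy: f64 = pts.iter().map(|(_, y)| (y - avgy).powi(2)).sum();
    let sxy: f64 = pts.iter().map(|(x, y)| (x - avgx) * (y - avgy)).sum();

    let slope = sxy / sxx;
    let intercept = avgy - slope * avgx;
    let r2 = (sxy * sxy) / (sxx * syy);

    // Matches the expected row `2 0 3 1 2 4 2 8 4`:
    // slope, intercept, count, r2, avgx, avgy, sxx, syy, sxy
    assert_eq!(
        (slope, intercept, n, r2, avgx, avgy, sxx, syy, sxy),
        (2.0, 0.0, 3.0, 1.0, 2.0, 4.0, 2.0, 8.0, 4.0)
    );
}
```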
OVER w AS sxx, + regr_syy(column2, column1) OVER w AS syy, + regr_sxy(column2, column1) OVER w AS sxy +FROM (VALUES (1,2), (2,4), (3,6), (3, NULL), (4, NULL), (5,15), (6,18), (7, 21)) AS t(column1, column2) +WINDOW w AS (ORDER BY column1 ROWS BETWEEN 2 PRECEDING AND CURRENT ROW); +---- +NULL NULL 1 NULL 1 2 0 0 0 +2 0 2 1 1.5 3 0.5 2 1 +2 0 3 1 2 4 2 8 4 +2 0 2 1 2.5 5 0.5 2 1 +NULL NULL 1 NULL 3 6 0 0 0 +NULL NULL 1 NULL 5 15 0 0 0 +3 0 2 1 5.5 16.5 0.5 4.5 1.5 +3 0 3 1 6 18 2 18 6 diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 27d288cf60f0..25e2e4b453a3 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -16,7 +16,7 @@ # under the License. ############# -## Array expressions Tests +## Array Expressions Tests ############# @@ -43,6 +43,20 @@ CREATE TABLE values( (8, 15, 16, 8.8, NULL, '') ; +statement ok +CREATE TABLE values_without_nulls +AS VALUES + (1, 1, 2, 1.1, 'Lorem', 'A'), + (2, 3, 4, 2.2, 'ipsum', ''), + (3, 5, 6, 3.3, 'dolor', 'BB'), + (4, 7, 8, 4.4, 'sit', NULL), + (5, 9, 10, 5.5, 'amet', 'CCC'), + (6, 11, 12, 6.6, ',', 'DD'), + (7, 13, 14, 7.7, 'consectetur', 'E'), + (8, 15, 16, 8.8, 'adipiscing', 'F'), + (9, 17, 18, 9.9, 'elit', '') +; + statement ok CREATE TABLE arrays AS VALUES @@ -55,6 +69,18 @@ AS VALUES (make_array(make_array(15, 16),make_array(NULL, 18)), make_array(16.6, 17.7, 18.8), NULL) ; +statement ok +CREATE TABLE slices +AS VALUES + (make_array(NULL, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, NULL, 20), 2, -4), + (make_array(21, 22, 23, NULL, 25, 26, 27, 28, 29, 30), 0, 0), + (make_array(31, 32, 33, 34, 35, NULL, 37, 38, 39, 40), -4, -7), + (NULL, 4, 5), + (make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), NULL, 6), + (make_array(51, 52, NULL, 54, 55, 56, 57, 58, 59, 60), 5, NULL) +; + statement ok CREATE TABLE nested_arrays AS VALUES @@ -213,6 +239,18 @@ NULL 44 5 @ [51, 52, , 54, 55, 56, 57, 58, 59, 60] 55 NULL ^ [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] 66 7 NULL +# slices table +query ?II +select column1, column2, column3 from slices; +---- +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1 1 +[11, 12, 13, 14, 15, 16, 17, 18, , 20] 2 -4 +[21, 22, 23, , 25, 26, 27, 28, 29, 30] 0 0 +[31, 32, 33, 34, 35, , 37, 38, 39, 40] -4 -7 +NULL 4 5 +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] NULL 6 +[51, 52, , 54, 55, 56, 57, 58, 59, 60] 5 NULL + query ??I? 
select column1, column2, column3, column4 from arrays_values_v2; ---- @@ -250,6 +288,178 @@ select column1, column2, column3, column4 from nested_arrays_with_repeating_elem [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [19, 20, 21] [28, 29, 30] 5 [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] [28, 29, 30] [37, 38, 39] 10 + +### Array index + + +## array[i] + +# single index with scalars #1 (positive index) +query IRT +select make_array(1, 2, 3)[1], make_array(1.0, 2.0, 3.0)[2], make_array('h', 'e', 'l', 'l', 'o')[3]; +---- +1 2 l + +# single index with scalars #2 (zero index) +query I +select make_array(1, 2, 3)[0]; +---- +NULL + +# single index with scalars #3 (negative index) +query IRT +select make_array(1, 2, 3)[-1], make_array(1.0, 2.0, 3.0)[-2], make_array('h', 'e', 'l', 'l', 'o')[-3]; +---- +3 2 l + +# single index with scalars #4 (complex index) +query IRT +select make_array(1, 2, 3)[1 + 2 - 1], make_array(1.0, 2.0, 3.0)[2 * 1 * 0 - 2], make_array('h', 'e', 'l', 'l', 'o')[2 - 3]; +---- +2 2 o + +# single index with columns #1 (positive index) +query ?RT +select column1[2], column2[3], column3[1] from arrays; +---- +[3, ] 3.3 L +[5, 6] 6.6 i +[7, 8] 9.9 d +[9, 10] 12.2 s +NULL 15.5 a +[13, 14] NULL , +[, 18] 18.8 NULL + +# single index with columns #2 (zero index) +query ?RT +select column1[0], column2[0], column3[0] from arrays; +---- +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL +NULL NULL NULL + +# single index with columns #3 (negative index) +query ?RT +select column1[-2], column2[-3], column3[-1] from arrays; +---- +[, 2] 1.1 m +[3, 4] NULL m +[5, 6] 7.7 r +[7, ] 10.1 t +NULL 13.3 t +[11, 12] NULL , +[15, 16] 16.6 NULL + +# single index with columns #4 (complex index) +query ?RT +select column1[9 - 7], column2[2 * 0], column3[1 - 3] from arrays; +---- +[3, ] NULL e +[5, 6] NULL u +[7, 8] NULL o +[9, 10] NULL i +NULL NULL e +[13, 14] NULL NULL +[, 18] NULL NULL + +# TODO: support index as column +# single index with columns #5 (index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and index as columns +# single index with columns #6 (argument and index as columns) +# query I +# select column1[column2] from arrays_with_repeating_elements; +# ---- + +## array[i:j] + +# multiple index with columns #1 (positive index) +query ??? +select make_array(1, 2, 3)[1:2], make_array(1.0, 2.0, 3.0)[2:3], make_array('h', 'e', 'l', 'l', 'o')[2:4]; +---- +[1, 2] [2.0, 3.0] [e, l, l] + +# multiple index with columns #2 (zero index) +query ??? +select make_array(1, 2, 3)[0:0], make_array(1.0, 2.0, 3.0)[0:2], make_array('h', 'e', 'l', 'l', 'o')[0:6]; +---- +[] [1.0, 2.0] [h, e, l, l, o] + +# TODO: support multiple negative index +# multiple index with columns #3 (negative index) +# query II +# select make_array(1, 2, 3)[-3:-1], make_array(1.0, 2.0, 3.0)[-3:-1], make_array('h', 'e', 'l', 'l', 'o')[-2:0]; +# ---- + +# TODO: support complex index +# multiple index with columns #4 (complex index) +# query III +# select make_array(1, 2, 3)[2 + 1 - 1:10], make_array(1.0, 2.0, 3.0)[2 | 2:10], make_array('h', 'e', 'l', 'l', 'o')[6 ^ 6:10]; +# ---- + +# multiple index with columns #1 (positive index) +query ??? 
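# The bracket indexing exercised in this section is 1-based: index 0 and out-of-range
# indexes return NULL, and negative indexes count back from the end of the array.
# Slices a[i:j] include both endpoints, clamp out-of-range bounds to the array, and
# return an empty array when the range selects nothing. For example, with a = [10, 20, 30]:
# a[-1] = 30, a[0] IS NULL, a[1:2] = [10, 20], and a[0:9] = [10, 20, 30].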
+select column1[2:4], column2[1:4], column3[3:4] from arrays; +---- +[[3, ]] [1.1, 2.2, 3.3] [r, e] +[[5, 6]] [, 5.5, 6.6] [, u] +[[7, 8]] [7.7, 8.8, 9.9] [l, o] +[[9, 10]] [10.1, , 12.2] [t] +[] [13.3, 14.4, 15.5] [e, t] +[[13, 14]] [] [] +[[, 18]] [16.6, 17.7, 18.8] [] + +# multiple index with columns #2 (zero index) +query ??? +select column1[0:5], column2[0:3], column3[0:9] from arrays; +---- +[[, 2], [3, ]] [1.1, 2.2, 3.3] [L, o, r, e, m] +[[3, 4], [5, 6]] [, 5.5, 6.6] [i, p, , u, m] +[[5, 6], [7, 8]] [7.7, 8.8, 9.9] [d, , l, o, r] +[[7, ], [9, 10]] [10.1, , 12.2] [s, i, t] +[] [13.3, 14.4, 15.5] [a, m, e, t] +[[11, 12], [13, 14]] [] [,] +[[15, 16], [, 18]] [16.6, 17.7, 18.8] [] + +# TODO: support negative index +# multiple index with columns #3 (negative index) +# query ?RT +# select column1[-2:-4], column2[-3:-5], column3[-1:-4] from arrays; +# ---- +# [, 2] 1.1 m + +# TODO: support complex index +# multiple index with columns #4 (complex index) +# query ?RT +# select column1[9 - 7:2 + 2], column2[1 * 0:2 * 3], column3[1 + 1 - 0:5 % 3] from arrays; +# ---- + +# TODO: support first index as column +# multiple index with columns #5 (first index as column) +# query ? +# select make_array(1, 2, 3, 4, 5)[column2:4] from arrays_with_repeating_elements +# ---- + +# TODO: support last index as column +# multiple index with columns #6 (last index as column) +# query ?RT +# select make_array(1, 2, 3, 4, 5)[2:column3] from arrays_with_repeating_elements; +# ---- + +# TODO: support argument and indices as column +# multiple index with columns #7 (argument and indices as column) +# query ?RT +# select column1[column2:column3] from arrays_with_repeating_elements; +# ---- + + ### Array function tests @@ -363,23 +573,6 @@ select make_array(a, b, c, d) from values; [7.0, 13.0, 14.0, ] [8.0, 15.0, 16.0, 8.8] -# make_array null handling -query ?B?BB -select - make_array(a), make_array(a)[1] IS NULL, - make_array(e, f), make_array(e, f)[1] IS NULL, make_array(e, f)[2] IS NULL -from values; ----- -[1] false [Lorem, A] false false -[2] false [ipsum, ] false false -[3] false [dolor, BB] false false -[4] false [sit, ] false true -[] true [amet, CCC] false false -[5] false [,, DD] false false -[6] false [consectetur, E] false false -[7] false [adipiscing, F] false false -[8] false [, ] true false - # make_array with column of list query ?? 
select column1, column5 from arrays_values_without_nulls; @@ -400,6 +593,257 @@ from arrays_values_without_nulls; [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [6, 7]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [50, 51, 52]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [8, 9]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [50, 51, 52]] +## array_element (aliases: array_extract, list_extract, list_element) + +# array_element scalar function #1 (with positive index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element scalar function #2 (with positive index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); +---- +NULL NULL + +# array_element scalar function #3 (with zero) +query IT +select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); +---- +NULL NULL + +# array_element scalar function #4 (with NULL) +query error +select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +NULL NULL + +# array_element scalar function #5 (with negative index) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); +---- +4 l + +# array_element scalar function #6 (with negative index; out of bounds) +query IT +select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); +---- +NULL NULL + +# array_element scalar function #7 (nested array) +query ? +select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); +---- +[1, 2, 3, 4, 5] + +# array_extract scalar function #8 (function alias `array_slice`) +query IT +select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_element scalar function #9 (function alias `array_slice`) +query IT +select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# list_extract scalar function #10 (function alias `array_slice`) +query IT +select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); +---- +2 l + +# array_element with columns +query I +select array_element(column1, column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + +# array_element with columns and scalars +query II +select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + +## array_slice (aliases: list_slice) + +# array_slice scalar function #1 (with positive indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice scalar function #2 (with positive indexes; full array) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 5); +---- +[1, 2, 3, 4, 5] [h, e, l, l, o] + +# array_slice scalar function #3 (with positive indexes; first index = second index) +query ?? 
+select array_slice(make_array(1, 2, 3, 4, 5), 4, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 3); +---- +[4] [l] + +# array_slice scalar function #4 (with positive indexes; first index > second_index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 4, 1); +---- +[] [] + +# array_slice scalar function #5 (with positive indexes; out of bounds) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, 6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, 7); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #6 (with positive indexes; nested array) +query ? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1, 1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #7 (with zero and positive number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #8 (with NULL and positive number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3); +---- +[1, 2, 3, 4] [h, e, l] + +# array_slice scalar function #9 (with positive number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #10 (with zero-zero) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, 0); +---- +[] [] + +# array_slice scalar function #11 (with NULL-NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL); +---- +[] [] + +# array_slice scalar function #12 (with zero and negative number) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 0, -3); +---- +[1] [h, e] + +# array_slice scalar function #13 (with negative number and NULL) +query error +select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL); +---- +[2, 3, 4, 5] [l, l, o] + +# array_slice scalar function #14 (with NULL and negative number) +query error +select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3); +---- +[1] [h, e] + +# array_slice scalar function #15 (with negative indexes) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -1); +---- +[2, 3, 4] [l, l] + +# array_slice scalar function #16 (with negative indexes; almost full array (only with negative indices cannot return full array)) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -5, -1), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -5, -1); +---- +[1, 2, 3, 4] [h, e, l, l] + +# array_slice scalar function #17 (with negative indexes; first index = second index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -3); +---- +[] [] + +# array_slice scalar function #18 (with negative indexes; first index > second_index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -4, -6), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, -6); +---- +[] [] + +# array_slice scalar function #19 (with negative indexes; out of bounds) +query ?? 
+select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3); +---- +[] [] + +# array_slice scalar function #20 (with negative indexes; nested array) +query ? +select array_slice(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), -2, -1); +---- +[[1, 2, 3, 4, 5]] + +# array_slice scalar function #21 (with first positive index and last negative index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), 2, -3), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, -2); +---- +[2] [e, l] + +# array_slice scalar function #22 (with first negative index and last positive index) +query ?? +select array_slice(make_array(1, 2, 3, 4, 5), -2, 5), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, 4); +---- +[4, 5] [l, l] + +# list_slice scalar function #23 (function alias `array_slice`) +query ?? +select list_slice(make_array(1, 2, 3, 4, 5), 2, 4), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 1, 2); +---- +[2, 3, 4] [h, e] + +# array_slice with columns +query ? +select array_slice(column1, column2, column3) from slices; +---- +[] +[12, 13, 14, 15, 16] +[] +[] +[] +[41, 42, 43, 44, 45, 46] +[55, 56, 57, 58, 59, 60] + +# TODO: support NULLS in output instead of `[]` +# array_slice with columns and scalars +query ??? +select array_slice(make_array(1, 2, 3, 4, 5), column2, column3), array_slice(column1, 3, column3), array_slice(column1, column2, 5) from slices; +---- +[1] [] [, 2, 3, 4, 5] +[] [13, 14, 15, 16] [12, 13, 14, 15] +[] [] [21, 22, 23, , 25] +[] [33] [] +[4, 5] [] [] +[1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] +[5] [, 54, 55, 56, 57, 58, 59, 60] [55] + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) # array_append scalar function #1 @@ -566,25 +1010,71 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma [[1, 11, 111], [1, 2, 3], [2, 9, 1], [7, 8, 9], [1, 2, 3], [1, 7, 4], [4, 5, 6]] [[7, 8, 9], [1, 2, 3], [11, 12, 13]] [[1, 11, 111], [4, 5, 6], [10, 11, 12], [4, 9, 8], [7, 8, 9], [10, 11, 12], [1, 8, 7]] [[10, 11, 12], [1, 2, 3], [11, 12, 13]] -## array_fill +## array_repeat (aliases: `list_repeat`) -# array_fill scalar function #1 +# array_repeat scalar function #1 query ??? -select array_fill(11, make_array(1, 2, 3)), array_fill(3, make_array(2, 3)), array_fill(2, make_array(2)); +select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4); ---- -[[[11, 11, 11], [11, 11, 11]]] [[3, 3, 3], [3, 3, 3]] [2, 2] +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] -# array_fill scalar function #2 -query ?? -select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, 2)); +# array_repeat scalar function #2 (element as list) +query ??? +select array_repeat([1], 5), array_repeat([1.1, 2.2, 3.3], 3), array_repeat([[1, 2], [3, 4]], 2); ---- -[[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]] +[[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] -# array_fill scalar function #3 +# list_repeat scalar function #3 (function alias: `array_repeat`) +query ??? +select list_repeat(1, 5), list_repeat(3.14, 3), list_repeat('l', 4); +---- +[1, 1, 1, 1, 1] [3.14, 3.14, 3.14] [l, l, l, l] + +# array_repeat with columns #1 +query ? 
+select array_repeat(column4, column1) from values_without_nulls; +---- +[1.1] +[2.2, 2.2] +[3.3, 3.3, 3.3] +[4.4, 4.4, 4.4, 4.4] +[5.5, 5.5, 5.5, 5.5, 5.5] +[6.6, 6.6, 6.6, 6.6, 6.6, 6.6] +[7.7, 7.7, 7.7, 7.7, 7.7, 7.7, 7.7] +[8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8, 8.8] +[9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9, 9.9] + +# array_repeat with columns #2 (element as list) query ? -select array_fill(1, make_array()) +select array_repeat(column1, column3) from arrays_values_without_nulls; ---- -[] +[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] +[[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] +[[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] +[[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] + +# array_repeat with columns and scalars #1 +query ?? +select array_repeat(1, column1), array_repeat(column4, 3) from values_without_nulls; +---- +[1] [1.1, 1.1, 1.1] +[1, 1] [2.2, 2.2, 2.2] +[1, 1, 1] [3.3, 3.3, 3.3] +[1, 1, 1, 1] [4.4, 4.4, 4.4] +[1, 1, 1, 1, 1] [5.5, 5.5, 5.5] +[1, 1, 1, 1, 1, 1] [6.6, 6.6, 6.6] +[1, 1, 1, 1, 1, 1, 1] [7.7, 7.7, 7.7] +[1, 1, 1, 1, 1, 1, 1, 1] [8.8, 8.8, 8.8] +[1, 1, 1, 1, 1, 1, 1, 1, 1] [9.9, 9.9, 9.9] + +# array_repeat with columns and scalars #2 (element as list) +query ?? +select array_repeat([1], column3), array_repeat(column1, 3) from arrays_values_without_nulls; +---- +[[1]] [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] +[[1], [1]] [[11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]] +[[1], [1], [1]] [[21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]] +[[1], [1], [1], [1]] [[31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40], [31, 32, 33, 34, 35, 26, 37, 38, 39, 40]] ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -1140,7 +1630,7 @@ h,e,l,l,o 1-2-3-4-5 1|2|3 # array_to_string scalar function #2 query TTT -select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_fill(3, [3, 2, 2]), '/\'); +select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]], '+'), array_to_string(array_repeat(array_repeat(array_repeat(3, 2), 2), 3), '/\'); ---- 11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3 @@ -1240,7 +1730,7 @@ select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinali # cardinality scalar function #2 query II -select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3, array[3, 2, 3])); +select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_repeat(array_repeat(array_repeat(3, 3), 2), 3)); ---- 6 18 @@ -1430,31 +1920,7 @@ select array_remove_all(make_array([1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [28, 29, 30], [22, 23, 24]] [[19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24], [19, 20, 21], [25, 26, 27], [19, 20, 21], [22, 23, 24], [19, 20, 21], [19, 20, 21]] [[1, 2, 3], [4, 5, 6], [4, 5, 6], [10, 11, 12], [13, 14, 15], [10, 11, 12], [10, 11, 12], [19, 20, 21], [19, 20, 21], [19, 20, 21], [22, 23, 24]] [[28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], 
[34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]] -## trim_array - -# trim_array scalar function #1 -query ??? -select trim_array(make_array(1, 2, 3, 4, 5), 2), trim_array(['h', 'e', 'l', 'l', 'o'], 3), trim_array([1.0, 2.0, 3.0], 2); ----- -[1, 2, 3] [h, e] [1.0] - -# trim_array scalar function #2 -query ?? -select trim_array([[1, 2], [3, 4], [5, 6]], 2), trim_array(array_fill(4, [3, 4, 2]), 2); ----- -[[1, 2]] [[[4, 4], [4, 4], [4, 4], [4, 4]]] - -# trim_array scalar function #3 -query ? -select array_concat(trim_array(make_array(1, 2, 3), 3), make_array(4, 5), make_array()); ----- -[4, 5] - -# trim_array scalar function #4 -query ?? -select trim_array(make_array(), 0), trim_array(make_array(), 1) ----- -[] [] +## trim_array (deprecated) ## array_length (aliases: `list_length`) @@ -1477,10 +1943,10 @@ select array_length(make_array(1, 2, 3, 4, 5), 2), array_length(make_array(1, 2, NULL NULL 2 # array_length scalar function #4 -query IIII -select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, 2, 5]), 2), array_length(array_fill(3, [3, 2, 5]), 3), array_length(array_fill(3, [3, 2, 5]), 4); +query II +select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- -3 2 5 NULL +3 2 # array_length scalar function #5 query III @@ -1530,7 +1996,7 @@ select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), # array_dims scalar function #2 query ?? -select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4])); +select array_dims(array_repeat(array_repeat(array_repeat(2, 3), 2), 1)), array_dims(array_repeat(array_repeat(array_repeat(3, 4), 5), 2)); ---- [1, 2, 3] [2, 5, 4] @@ -1568,7 +2034,7 @@ select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])) # array_ndims scalar function #2 query II -select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); +select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ---- 3 21 @@ -1858,12 +2324,18 @@ select array_concat(column1, [7]) from arrays_values_v2; statement ok drop table values; +statement ok +drop table values_without_nulls; + statement ok drop table nested_arrays; statement ok drop table arrays; +statement ok +drop table slices; + statement ok drop table arrays_values; diff --git a/datafusion/core/tests/sqllogictests/test_files/explain.slt b/datafusion/core/tests/sqllogictests/test_files/explain.slt index bd3513550a4e..aa560961d2f3 100644 --- a/datafusion/core/tests/sqllogictests/test_files/explain.slt +++ b/datafusion/core/tests/sqllogictests/test_files/explain.slt @@ -159,7 +159,7 @@ query TT EXPLAIN INSERT INTO sink_table SELECT * FROM aggregate_test_100 ORDER by c1 ---- logical_plan -Dml: op=[Insert] table=[sink_table] +Dml: op=[Insert Into] table=[sink_table] --Projection: aggregate_test_100.c1 AS c1, aggregate_test_100.c2 AS c2, aggregate_test_100.c3 AS c3, aggregate_test_100.c4 AS c4, aggregate_test_100.c5 AS c5, aggregate_test_100.c6 AS c6, aggregate_test_100.c7 AS c7, aggregate_test_100.c8 AS c8, aggregate_test_100.c9 AS c9, aggregate_test_100.c10 AS c10, aggregate_test_100.c11 AS c11, aggregate_test_100.c12 AS c12, aggregate_test_100.c13 AS c13 ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] diff --git 
a/datafusion/core/tests/sqllogictests/test_files/groupby.slt b/datafusion/core/tests/sqllogictests/test_files/groupby.slt index 5db2ad928007..898abceefb41 100644 --- a/datafusion/core/tests/sqllogictests/test_files/groupby.slt +++ b/datafusion/core/tests/sqllogictests/test_files/groupby.slt @@ -1960,21 +1960,20 @@ SortPreservingMergeExec: [col0@0 ASC NULLS LAST] --SortExec: expr=[col0@0 ASC NULLS LAST] ----ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] ------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)] ---------SortExec: expr=[col0@3 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 ---------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallyOrdered -----------------SortExec: expr=[col0@3 ASC NULLS LAST] -------------------CoalesceBatchesExec: target_batch_size=8192 ---------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 ---------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------MemoryExec: partitions=1, partition_sizes=[3] -----------------------CoalesceBatchesExec: target_batch_size=8192 -------------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 ---------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -----------------------------MemoryExec: partitions=1, partition_sizes=[3] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1)], ordering_mode=PartiallyOrdered +--------------SortExec: expr=[col0@3 ASC NULLS LAST] +----------------CoalesceBatchesExec: target_batch_size=8192 +------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. 
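The hunk above drops the extra SortExec over the FinalPartitioned aggregation. A minimal sketch of the underlying idea, assuming an ordering-sensitive aggregate such as LAST_VALUE(x ORDER BY k): once each Partial state records the ordering key it last saw, the Final stage can merge states by comparing keys instead of re-sorting rows. This models the idea only; it is not DataFusion's Accumulator API, and the types below are illustrative.

#[derive(Clone, Copy, Debug, Default)]
struct LastValueState {
    value: Option<f64>,
    order_key: Option<i64>,
}

impl LastValueState {
    /// Update from one input row; rows within a partition arrive sorted by `key`.
    fn update(&mut self, value: f64, key: i64) {
        if self.order_key.map_or(true, |k| key >= k) {
            self.value = Some(value);
            self.order_key = Some(key);
        }
    }

    /// Merge another partial state: keep whichever state saw the larger ordering key.
    fn merge(&mut self, other: &LastValueState) {
        if other.order_key > self.order_key {
            *self = *other;
        }
    }
}

fn main() {
    // Two partitions build partial states independently...
    let mut p1 = LastValueState::default();
    p1.update(10.0, 1);
    p1.update(20.0, 3);

    let mut p2 = LastValueState::default();
    p2.update(30.0, 2);

    // ...and the final stage combines them without sorting any rows.
    let mut final_state = LastValueState::default();
    final_state.merge(&p1);
    final_state.merge(&p2);
    assert_eq!(final_state.value, Some(20.0)); // ordering key 3 wins
    println!("{final_state:?}");
}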
@@ -2213,6 +2212,21 @@ CREATE TABLE sales_global (zip_code INT, (1, 'TUR', 4, '2022-01-03 10:00:00'::timestamp, 'TRY', 100.0), (0, 'GRC', 4, '2022-01-03 10:00:00'::timestamp, 'EUR', 80.0) +# create a new table named exchange rates +statement ok +CREATE TABLE exchange_rates ( + sn INTEGER, + ts TIMESTAMP, + currency_from VARCHAR(3), + currency_to VARCHAR(3), + rate DECIMAL(10,2) +) as VALUES + (0, '2022-01-01 06:00:00'::timestamp, 'EUR', 'USD', 1.10), + (1, '2022-01-01 08:00:00'::timestamp, 'TRY', 'USD', 0.10), + (2, '2022-01-01 11:30:00'::timestamp, 'EUR', 'USD', 1.12), + (3, '2022-01-02 12:00:00'::timestamp, 'TRY', 'USD', 0.11), + (4, '2022-01-03 10:00:00'::timestamp, 'EUR', 'USD', 1.12) + # test_ordering_sensitive_aggregation # ordering sensitive requirement should add a SortExec in the final plan. To satisfy amount ASC # in the aggregation @@ -2689,13 +2703,12 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@2 as fv2] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@1 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)], ordering_mode=None -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------------SortExec: expr=[ts@1 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2727,13 +2740,12 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] ---------SortExec: expr=[ts@1 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)], ordering_mode=None -----------------SortExec: expr=[ts@1 ASC NULLS LAST] -------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=8192 
+----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] +--------------SortExec: expr=[ts@1 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, @@ -2760,8 +2772,8 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv2] --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] -----SortPreservingMergeExec: [ts@0 ASC NULLS LAST] -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)], ordering_mode=None +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] --------SortExec: expr=[ts@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2787,8 +2799,8 @@ Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS physical_plan ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] --AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] -----SortPreservingMergeExec: [ts@0 ASC NULLS LAST] -------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)], ordering_mode=None +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount), FIRST_VALUE(sales_global.amount)] --------SortExec: expr=[ts@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2812,8 +2824,8 @@ Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS L physical_plan ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] --AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] -----SortPreservingMergeExec: [ts@0 ASC NULLS LAST] -------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)], ordering_mode=None +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] --------SortExec: expr=[ts@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2836,8 +2848,8 @@ Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS physical_plan ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] --AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] -----SortPreservingMergeExec: [ts@0 DESC] -------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)], 
ordering_mode=None +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] --------SortExec: expr=[ts@0 DESC] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2860,8 +2872,8 @@ Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NUL physical_plan ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] --AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] -----SortPreservingMergeExec: [amount@0 ASC NULLS LAST] -------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)], ordering_mode=None +----CoalescePartitionsExec +------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount)] --------SortExec: expr=[amount@0 ASC NULLS LAST] ----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ------------MemoryExec: partitions=1, partition_sizes=[1] @@ -2889,13 +2901,12 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] ---------SortExec: expr=[amount@1 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)], ordering_mode=None -----------------SortExec: expr=[amount@1 ASC NULLS LAST] -------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount)] +--------------SortExec: expr=[amount@1 ASC NULLS LAST] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] query T? 
SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 @@ -2926,13 +2937,12 @@ SortPreservingMergeExec: [country@0 ASC NULLS LAST] --SortExec: expr=[country@0 ASC NULLS LAST] ----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] ------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] ---------SortExec: expr=[amount@1 DESC] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 ---------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)], ordering_mode=None -----------------SortExec: expr=[amount@1 DESC] -------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 ---------------------MemoryExec: partitions=1, partition_sizes=[1] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 +------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount), LAST_VALUE(sales_global.amount), LAST_VALUE(sales_global.amount)] +--------------SortExec: expr=[amount@1 DESC] +----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +------------------MemoryExec: partitions=1, partition_sizes=[1] query T?RR SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, @@ -2946,6 +2956,39 @@ FRA [200.0, 50.0] 50 50 GRC [80.0, 30.0] 30 30 TUR [100.0, 75.0] 75 75 +# make sure that query below runs in multi partitions +statement ok +set datafusion.execution.target_partitions = 8; + +query ? 
+SELECT ARRAY_AGG(e.rate ORDER BY e.sn) +FROM sales_global AS s +JOIN exchange_rates AS e +ON s.currency = e.currency_from AND + e.currency_to = 'USD' AND + s.ts >= e.ts +GROUP BY s.sn +ORDER BY s.sn; +---- +[1.10] +[1.10] +[0.10] +[1.10, 1.12] +[1.10, 0.10, 1.12, 0.11, 1.12] + + +query I +SELECT FIRST_VALUE(C order by c ASC) as first_c +FROM multiple_ordered_table +GROUP BY d +ORDER BY first_c +---- +0 +1 +4 +9 +15 + query ITIPTR SELECT s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount ORDER BY e.sn) AS last_rate FROM sales_global AS s diff --git a/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt b/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt index fcb653cedd16..7cf845c16d73 100644 --- a/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt +++ b/datafusion/core/tests/sqllogictests/test_files/information_schema_columns.slt @@ -50,4 +50,4 @@ statement ok drop table t1 statement ok -drop table t2 \ No newline at end of file +drop table t2 diff --git a/datafusion/core/tests/sqllogictests/test_files/insert.slt b/datafusion/core/tests/sqllogictests/test_files/insert.slt index 90a33bd1c5f7..e42d2ef0592d 100644 --- a/datafusion/core/tests/sqllogictests/test_files/insert.slt +++ b/datafusion/core/tests/sqllogictests/test_files/insert.slt @@ -57,7 +57,7 @@ FROM aggregate_test_100 ORDER by c1 ---- logical_plan -Dml: op=[Insert] table=[table_without_values] +Dml: op=[Insert Into] table=[table_without_values] --Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 @@ -120,20 +120,19 @@ COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWI FROM aggregate_test_100 ---- logical_plan -Dml: op=[Insert] table=[table_without_values] +Dml: op=[Insert Into] table=[table_without_values] --Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 ----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] ------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan InsertExec: sink=MemoryTable (partitions=1) ---CoalescePartitionsExec -----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 
PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] ---------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] -----------CoalesceBatchesExec: target_batch_size=8192 -------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 ---------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true +--ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }], mode=[Sorted] +------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 +------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true @@ -168,7 +167,7 @@ FROM aggregate_test_100 ORDER BY c1 ---- logical_plan 
-Dml: op=[Insert] table=[table_without_values] +Dml: op=[Insert Into] table=[table_without_values] --Projection: a1 AS a1, a2 AS a2 ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 @@ -212,7 +211,7 @@ query TT explain insert into table_without_values select c1 from aggregate_test_100 order by c1; ---- logical_plan -Dml: op=[Insert] table=[table_without_values] +Dml: op=[Insert Into] table=[table_without_values] --Projection: aggregate_test_100.c1 AS c1 ----Sort: aggregate_test_100.c1 ASC NULLS LAST ------TableScan: aggregate_test_100 projection=[c1] diff --git a/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt b/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt index fcc6d665c6da..daeb7aad9aa5 100644 --- a/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt +++ b/datafusion/core/tests/sqllogictests/test_files/join_disable_repartition_joins.slt @@ -63,18 +63,18 @@ GlobalLimitExec: skip=0, fetch=5 --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true # preserve_inner_join -query III nosort -SELECT t1.a, t2.a as a2, t2.b +query IIII nosort +SELECT t1.a, t1.b, t1.c, t2.a as a2 FROM annotated_data as t1 INNER JOIN annotated_data as t2 ON t1.d = t2.d ORDER BY a2, t2.b LIMIT 5 ---- -1 0 0 -1 0 0 -1 0 0 -1 0 0 -1 0 0 +0 0 0 0 +0 0 2 0 +0 0 3 0 +0 0 6 0 +0 0 20 0 query TT EXPLAIN SELECT t2.a as a2, t2.b diff --git a/datafusion/core/tests/sqllogictests/test_files/math.slt b/datafusion/core/tests/sqllogictests/test_files/math.slt index 152e8b78bdfa..fc27333ec0af 100644 --- a/datafusion/core/tests/sqllogictests/test_files/math.slt +++ b/datafusion/core/tests/sqllogictests/test_files/math.slt @@ -93,3 +93,9 @@ query RRRRRRR SELECT atan2(2.0, 1.0), atan2(-2.0, 1.0), atan2(2.0, -1.0), atan2(-2.0, -1.0), atan2(NULL, 1.0), atan2(2.0, NULL), atan2(NULL, NULL); ---- 1.107148717794 -1.107148717794 2.034443935796 -2.034443935796 NULL NULL NULL + +# nanvl +query RRR +SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) +---- +1 1 NaN \ No newline at end of file diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt index d5ce7737fba0..80f5bd6c9d78 100644 --- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt +++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt @@ -660,6 +660,41 @@ select round(log2(a), 5), round(log2(b), 5), round(log2(c), 5) from signed_integ NaN 13.28771 NaN NaN 6.64386 NaN +## nanvl + +# nanvl scalar function +query RRR rowsort +select nanvl(0, 1), nanvl(asin(10), 2), nanvl(3, asin(10)); +---- +0 2 3 + +# nanvl scalar nulls +query R rowsort +select nanvl(null, 64); +---- +NULL + +# nanvl scalar nulls #1 +query R rowsort +select nanvl(2, null); +---- +NULL + +# nanvl scalar nulls #2 +query R rowsort +select nanvl(null, null); +---- +NULL + +# nanvl with columns (round is needed to normalize the outputs of different operating systems) +query RRR rowsort +select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), 
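# nanvl(x, y) returns x unless x is NaN, in which case it returns y; asin(10) evaluates
# to NaN and is used in these tests purely to produce NaN inputs.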
round(nanvl(asin(d + e), 4), 5) from small_floats; +---- +0.7754 1.11977 -0.9273 +2 -0.20136 0.7754 +2 -1.11977 4 +NULL NULL NULL + ## pi # pi scalar function diff --git a/datafusion/core/tests/sqllogictests/test_files/struct.slt b/datafusion/core/tests/sqllogictests/test_files/struct.slt new file mode 100644 index 000000000000..2629b6b038a3 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/struct.slt @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +############# +## Struct Expressions Tests +############# + +statement ok +CREATE TABLE values( + a INT, + b FLOAT, + c VARCHAR +) AS VALUES + (1, 1.1, 'a'), + (2, 2.2, 'b'), + (3, 3.3, 'c') +; + +# struct[i] +query IRT +select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; +---- +1 2.55 a + +# struct[i] with columns +query R +select struct(a, b, c)['c1'] from values; +---- +1.1 +2.2 +3.3 + +# struct scalar function #1 +query ? +select struct(1, 3.14, 'e'); +---- +{c0: 1, c1: 3.14, c2: e} + +# struct scalar function with columns #1 +query ? +select struct(a, b, c) from values; +---- +{c0: 1, c1: 1.1, c2: a} +{c0: 2, c1: 2.2, c2: b} +{c0: 3, c1: 3.3, c2: c} + +statement ok +drop table values; \ No newline at end of file diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index a860e15fa2f5..c8478499365e 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -25,7 +25,7 @@ use std::{ use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; /// Configuration options for Execution context -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct SessionConfig { /// Configuration options options: ConfigOptions, diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 011cd72cbb9a..f8fc9fcdbbbb 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -77,7 +77,9 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { fn reserved(&self) -> usize; } -/// A memory consumer that can be tracked by [`MemoryReservation`] in a [`MemoryPool`] +/// A memory consumer that can be tracked by [`MemoryReservation`] in +/// a [`MemoryPool`]. 
All allocations are registered to a particular +/// `MemoryConsumer`; #[derive(Debug)] pub struct MemoryConsumer { name: String, @@ -113,20 +115,40 @@ impl MemoryConsumer { pub fn register(self, pool: &Arc) -> MemoryReservation { pool.register(&self); MemoryReservation { - consumer: self, + registration: Arc::new(SharedRegistration { + pool: Arc::clone(pool), + consumer: self, + }), size: 0, - policy: Arc::clone(pool), } } } -/// A [`MemoryReservation`] tracks a reservation of memory in a [`MemoryPool`] -/// that is freed back to the pool on drop +/// A registration of a [`MemoryConsumer`] with a [`MemoryPool`]. +/// +/// Calls [`MemoryPool::unregister`] on drop to return any memory to +/// the underlying pool. #[derive(Debug)] -pub struct MemoryReservation { +struct SharedRegistration { + pool: Arc, consumer: MemoryConsumer, +} + +impl Drop for SharedRegistration { + fn drop(&mut self) { + self.pool.unregister(&self.consumer); + } +} + +/// A [`MemoryReservation`] tracks an individual reservation of a +/// number of bytes of memory in a [`MemoryPool`] that is freed back +/// to the pool on drop. +/// +/// The reservation can be grown or shrunk over time. +#[derive(Debug)] +pub struct MemoryReservation { + registration: Arc, size: usize, - policy: Arc, } impl MemoryReservation { @@ -135,7 +157,8 @@ impl MemoryReservation { self.size } - /// Frees all bytes from this reservation returning the number of bytes freed + /// Frees all bytes from this reservation back to the underlying + /// pool, returning the number of bytes freed. pub fn free(&mut self) -> usize { let size = self.size; if size != 0 { @@ -151,7 +174,7 @@ impl MemoryReservation { /// Panics if `capacity` exceeds [`Self::size`] pub fn shrink(&mut self, capacity: usize) { let new_size = self.size.checked_sub(capacity).unwrap(); - self.policy.shrink(self, capacity); + self.registration.pool.shrink(self, capacity); self.size = new_size } @@ -176,22 +199,55 @@ impl MemoryReservation { /// Increase the size of this reservation by `capacity` bytes pub fn grow(&mut self, capacity: usize) { - self.policy.grow(self, capacity); + self.registration.pool.grow(self, capacity); self.size += capacity; } - /// Try to increase the size of this reservation by `capacity` bytes + /// Try to increase the size of this reservation by `capacity` + /// bytes, returning error if there is insufficient capacity left + /// in the pool. pub fn try_grow(&mut self, capacity: usize) -> Result<()> { - self.policy.try_grow(self, capacity)?; + self.registration.pool.try_grow(self, capacity)?; self.size += capacity; Ok(()) } + + /// Splits off `capacity` bytes from this [`MemoryReservation`] + /// into a new [`MemoryReservation`] with the same + /// [`MemoryConsumer`]. 
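// A usage sketch of the reservation-splitting API added in this change. The pool type,
// consumer name and byte counts are illustrative; the calls follow the signatures of
// `split`, `new_empty` and `take` introduced here.
//
//     let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024));
//     let mut r1 = MemoryConsumer::new("sort").register(&pool);
//     r1.try_grow(100).unwrap();   // the pool now accounts for 100 bytes
//     let r2 = r1.split(40);       // r1 tracks 60 bytes, r2 tracks 40; the pool total is unchanged
//     drop(r2);                    // returns 40 bytes to the pool; r1 keeps its 60
//     let mut r3 = r1.new_empty(); // zero-sized reservation sharing r1's registration
//     r3.try_grow(8).unwrap();     // grows independently of r1
//     let r4 = r1.take();          // moves all of r1's remaining bytes into r4
//     assert_eq!(r1.size(), 0);    // r1 is empty but can grow again later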
+ /// + /// This can be useful to free part of this reservation with RAAI + /// style dropping + /// + /// # Panics + /// + /// Panics if `capacity` exceeds [`Self::size`] + pub fn split(&mut self, capacity: usize) -> MemoryReservation { + self.size = self.size.checked_sub(capacity).unwrap(); + Self { + size: capacity, + registration: self.registration.clone(), + } + } + + /// Returns a new empty [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn new_empty(&self) -> Self { + Self { + size: 0, + registration: self.registration.clone(), + } + } + + /// Splits off all the bytes from this [`MemoryReservation`] into + /// a new [`MemoryReservation`] with the same [`MemoryConsumer`] + pub fn take(&mut self) -> MemoryReservation { + self.split(self.size) + } } impl Drop for MemoryReservation { fn drop(&mut self) { self.free(); - self.policy.unregister(&self.consumer); } } @@ -251,4 +307,59 @@ mod tests { a2.try_grow(25).unwrap(); assert_eq!(pool.reserved(), 25); } + + #[test] + fn test_split() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + assert_eq!(r1.size(), 20); + assert_eq!(pool.reserved(), 20); + + // take 5 from r1, should still have same reservation split + let r2 = r1.split(5); + assert_eq!(r1.size(), 15); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 20); + + // dropping r1 frees 15 but retains 5 as they have the same consumer + drop(r1); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 5); + } + + #[test] + fn test_new_empty() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.new_empty(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 20); + assert_eq!(r2.size(), 5); + assert_eq!(pool.reserved(), 25); + } + + #[test] + fn test_take() { + let pool = Arc::new(GreedyMemoryPool::new(50)) as _; + let mut r1 = MemoryConsumer::new("r1").register(&pool); + + r1.try_grow(20).unwrap(); + let mut r2 = r1.take(); + r2.try_grow(5).unwrap(); + + assert_eq!(r1.size(), 0); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 25); + + // r1 can still grow again + r1.try_grow(3).unwrap(); + assert_eq!(r1.size(), 3); + assert_eq!(r2.size(), 25); + assert_eq!(pool.reserved(), 28); + } } diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 7b68a86244b7..1242ce025ca2 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -84,7 +84,11 @@ impl MemoryPool for GreedyMemoryPool { (new_used <= self.pool_size).then_some(new_used) }) .map_err(|used| { - insufficient_capacity_err(reservation, additional, self.pool_size - used) + insufficient_capacity_err( + reservation, + additional, + self.pool_size.saturating_sub(used), + ) })?; Ok(()) } @@ -159,13 +163,14 @@ impl MemoryPool for FairSpillPool { fn unregister(&self, consumer: &MemoryConsumer) { if consumer.can_spill { - self.state.lock().num_spill -= 1; + let mut state = self.state.lock(); + state.num_spill = state.num_spill.checked_sub(1).unwrap(); } } fn grow(&self, reservation: &MemoryReservation, additional: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable += additional, false => state.unspillable += additional, } @@ -173,7 +178,7 @@ impl MemoryPool for FairSpillPool { fn 
shrink(&self, reservation: &MemoryReservation, shrink: usize) { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => state.spillable -= shrink, false => state.unspillable -= shrink, } @@ -182,7 +187,7 @@ impl MemoryPool for FairSpillPool { fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { let mut state = self.state.lock(); - match reservation.consumer.can_spill { + match reservation.registration.consumer.can_spill { true => { // The total amount of memory available to spilling consumers let spill_available = self.pool_size.saturating_sub(state.unspillable); @@ -230,7 +235,7 @@ fn insufficient_capacity_err( additional: usize, available: usize, ) -> DataFusionError { - DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.consumer.name, reservation.size, available)) + DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available)) } #[cfg(test)] @@ -247,7 +252,7 @@ mod tests { r1.grow(2000); assert_eq!(pool.reserved(), 2000); - let mut r2 = MemoryConsumer::new("s1") + let mut r2 = MemoryConsumer::new("r2") .with_can_spill(true) .register(&pool); // Can grow beyond capacity of pool @@ -256,10 +261,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); let err = r2.try_grow(1).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for s1 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); r1.shrink(1990); r2.shrink(2000); @@ -269,7 +274,7 @@ mod tests { r1.try_grow(10).unwrap(); assert_eq!(pool.reserved(), 20); - // Can grow a2 to 80 as only spilling consumer + // Can grow r2 to 80 as only spilling consumer r2.try_grow(80).unwrap(); assert_eq!(pool.reserved(), 100); @@ -279,19 +284,19 @@ mod tests { assert_eq!(r2.size(), 10); assert_eq!(pool.reserved(), 30); - let mut r3 = MemoryConsumer::new("s2") + let mut r3 = MemoryConsumer::new("r3") .with_can_spill(true) .register(&pool); let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); - //Shrinking a2 to zero doesn't allow a3 to allocate more than 45 + //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 r2.free(); let err = r3.try_grow(70).unwrap_err().to_string(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for s2 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes 
already allocated - maximum available is 40"); - // But dropping a2 does + // But dropping r2 does drop(r2); assert_eq!(pool.reserved(), 20); r3.try_grow(80).unwrap(); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index add1262237a8..876f746d05af 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -19,7 +19,7 @@ use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; use strum_macros::EnumIter; @@ -61,6 +61,24 @@ pub enum AggregateFunction { CovariancePop, /// Correlation Correlation, + /// Slope from linear regression + RegrSlope, + /// Intercept from linear regression + RegrIntercept, + /// Number of input rows in which both expressions are not null + RegrCount, + /// R-squared value from linear regression + RegrR2, + /// Average of the independent variable + RegrAvgx, + /// Average of the dependent variable + RegrAvgy, + /// Sum of squares of the independent variable + RegrSXX, + /// Sum of squares of the dependent variable + RegrSYY, + /// Sum of products of pairs of numbers + RegrSXY, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -102,6 +120,15 @@ impl AggregateFunction { Covariance => "COVARIANCE", CovariancePop => "COVARIANCE_POP", Correlation => "CORRELATION", + RegrSlope => "REGR_SLOPE", + RegrIntercept => "REGR_INTERCEPT", + RegrCount => "REGR_COUNT", + RegrR2 => "REGR_R2", + RegrAvgx => "REGR_AVGX", + RegrAvgy => "REGR_AVGY", + RegrSXX => "REGR_SXX", + RegrSYY => "REGR_SYY", + RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", ApproxMedian => "APPROX_MEDIAN", @@ -152,6 +179,15 @@ impl FromStr for AggregateFunction { "var" => AggregateFunction::Variance, "var_pop" => AggregateFunction::VariancePop, "var_samp" => AggregateFunction::Variance, + "regr_slope" => AggregateFunction::RegrSlope, + "regr_intercept" => AggregateFunction::RegrIntercept, + "regr_count" => AggregateFunction::RegrCount, + "regr_r2" => AggregateFunction::RegrR2, + "regr_avgx" => AggregateFunction::RegrAvgx, + "regr_avgy" => AggregateFunction::RegrAvgy, + "regr_sxx" => AggregateFunction::RegrSXX, + "regr_syy" => AggregateFunction::RegrSYY, + "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_distinct" => AggregateFunction::ApproxDistinct, "approx_median" => AggregateFunction::ApproxMedian, @@ -162,9 +198,7 @@ impl FromStr for AggregateFunction { // other "grouping" => AggregateFunction::Grouping, _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))); + return plan_err!("There is no built-in function named {name}"); } }) } @@ -228,6 +262,15 @@ impl AggregateFunction { } AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => Ok(DataType::Float64), AggregateFunction::Avg 
=> avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", @@ -311,10 +354,18 @@ impl AggregateFunction { | AggregateFunction::LastValue => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::Covariance | AggregateFunction::CovariancePop => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::Correlation => { + AggregateFunction::Covariance + | AggregateFunction::CovariancePop + | AggregateFunction::Correlation + | AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::ApproxPercentileCont => { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 33db0f9eb1a4..061d0689cd97 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -23,7 +23,7 @@ use crate::{ conditional_expressions, struct_expressions, Signature, TypeSignature, Volatility, }; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use std::collections::HashMap; use std::fmt; use std::str::FromStr; @@ -89,6 +89,8 @@ pub enum BuiltinScalarFunction { Log10, /// log2 Log2, + /// nanvl + Nanvl, /// pi Pi, /// power @@ -127,8 +129,8 @@ pub enum BuiltinScalarFunction { ArrayHasAny, /// array_dims ArrayDims, - /// array_fill - ArrayFill, + /// array_element + ArrayElement, /// array_length ArrayLength, /// array_ndims @@ -145,20 +147,26 @@ pub enum BuiltinScalarFunction { ArrayRemoveN, /// array_remove_all ArrayRemoveAll, + /// array_repeat + ArrayRepeat, /// array_replace ArrayReplace, /// array_replace_n ArrayReplaceN, /// array_replace_all ArrayReplaceAll, + /// array_slice + ArraySlice, /// array_to_string ArrayToString, /// cardinality Cardinality, /// construct an array from columns MakeArray, - /// trim_array - TrimArray, + + // struct functions + /// struct + Struct, // string functions /// ascii @@ -257,8 +265,6 @@ pub enum BuiltinScalarFunction { Uuid, /// regexp_match RegexpMatch, - /// struct - Struct, /// arrow_typeof ArrowTypeof, } @@ -328,6 +334,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Log => Volatility::Immutable, BuiltinScalarFunction::Log10 => Volatility::Immutable, BuiltinScalarFunction::Log2 => Volatility::Immutable, + BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, BuiltinScalarFunction::Power => Volatility::Immutable, BuiltinScalarFunction::Round => Volatility::Immutable, @@ -346,22 +353,23 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable, BuiltinScalarFunction::ArrayHas => Volatility::Immutable, BuiltinScalarFunction::ArrayDims => Volatility::Immutable, - BuiltinScalarFunction::ArrayFill => Volatility::Immutable, + BuiltinScalarFunction::ArrayElement => Volatility::Immutable, BuiltinScalarFunction::ArrayLength => Volatility::Immutable, BuiltinScalarFunction::ArrayNdims => Volatility::Immutable, BuiltinScalarFunction::ArrayPosition => Volatility::Immutable, BuiltinScalarFunction::ArrayPositions => Volatility::Immutable, 
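(Editor's aside, not part of the patch: a small sketch of how the newly added regression aggregates resolve from their SQL names via the `FromStr` impl above; per the `return_type` and `signature` arms they all take two numeric arguments and report `Float64`. Assumes `AggregateFunction` is re-exported from `datafusion_expr`.)

```rust
use datafusion_expr::AggregateFunction;

fn main() {
    // Each new name maps onto its enum variant through `FromStr`.
    let slope: AggregateFunction = "regr_slope".parse().unwrap();
    assert_eq!(slope, AggregateFunction::RegrSlope);

    let count: AggregateFunction = "regr_count".parse().unwrap();
    assert_eq!(count, AggregateFunction::RegrCount);

    // Unknown names hit the `plan_err!` branch added in this patch.
    assert!("regr_unknown".parse::<AggregateFunction>().is_err());
}
```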
BuiltinScalarFunction::ArrayPrepend => Volatility::Immutable, + BuiltinScalarFunction::ArrayRepeat => Volatility::Immutable, BuiltinScalarFunction::ArrayRemove => Volatility::Immutable, BuiltinScalarFunction::ArrayRemoveN => Volatility::Immutable, BuiltinScalarFunction::ArrayRemoveAll => Volatility::Immutable, BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, + BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, - BuiltinScalarFunction::TrimArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::Btrim => Volatility::Immutable, @@ -480,9 +488,7 @@ impl BuiltinScalarFunction { // or the execution panics. if input_expr_types.is_empty() && !self.supports_zero_argument() { - return Err(DataFusionError::Plan( - self.generate_signature_error_msg(input_expr_types), - )); + return plan_err!("{}", self.generate_signature_error_msg(input_expr_types)); } // verify that this is a valid set of data types for this function @@ -524,11 +530,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDims => { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } - BuiltinScalarFunction::ArrayFill => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[1].clone(), - true, - )))), + BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { + List(field) => Ok(field.data_type().clone()), + _ => Err(DataFusionError::Internal(format!( + "The {self} function can only accept list as the first argument" + ))), + }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), BuiltinScalarFunction::ArrayNdims => Ok(UInt64), BuiltinScalarFunction::ArrayPosition => Ok(UInt64), @@ -536,12 +543,18 @@ impl BuiltinScalarFunction { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } BuiltinScalarFunction::ArrayPrepend => Ok(input_expr_types[1].clone()), + BuiltinScalarFunction::ArrayRepeat => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), BuiltinScalarFunction::ArrayRemove => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayRemoveN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayRemoveAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { @@ -558,16 +571,6 @@ impl BuiltinScalarFunction { Ok(List(Arc::new(Field::new("item", expr_type, true)))) } }, - BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { - List(field) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => Err(DataFusionError::Internal(format!( - "The {self} function can only accept list as the first argument" - ))), - }, BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") @@ 
-760,6 +763,11 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, + BuiltinScalarFunction::Nanvl => match &input_expr_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + }, + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), BuiltinScalarFunction::Abs @@ -813,7 +821,7 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::ArrayHasAny | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayFill => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } @@ -823,6 +831,7 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), @@ -831,6 +840,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => { Signature::any(3, self.volatility()) } + BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } @@ -838,7 +848,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::MakeArray => { Signature::variadic_any(self.volatility()) } - BuiltinScalarFunction::TrimArray => Signature::any(2, self.volatility()), BuiltinScalarFunction::Struct => Signature::variadic( struct_expressions::SUPPORTED_STRUCT_TYPES.to_vec(), self.volatility(), @@ -1120,6 +1129,10 @@ impl BuiltinScalarFunction { ], self.volatility(), ), + BuiltinScalarFunction::Nanvl => Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + self.volatility(), + ), BuiltinScalarFunction::Factorial => { Signature::uniform(1, vec![Int64], self.volatility()) } @@ -1193,6 +1206,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Log => &["log"], BuiltinScalarFunction::Log10 => &["log10"], BuiltinScalarFunction::Log2 => &["log2"], + BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], BuiltinScalarFunction::Radians => &["radians"], @@ -1272,7 +1286,6 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Decode => &["decode"], // other functions - BuiltinScalarFunction::Struct => &["struct"], BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], // array functions @@ -1286,12 +1299,17 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { &["array_concat", "array_cat", "list_concat", "list_cat"] } BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"], + BuiltinScalarFunction::ArrayElement => &[ + "array_element", + "array_extract", + "list_element", + "list_extract", + ], BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], BuiltinScalarFunction::ArrayHas => { &["array_has", "list_has", "array_contains", "list_contains"] } - BuiltinScalarFunction::ArrayFill => &["array_fill"], 
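(Editor's aside, not part of the patch: a sketch of alias lookup for the added scalar functions, assuming `BuiltinScalarFunction` is re-exported from `datafusion_expr` and its `FromStr` resolves names through the alias table above.)

```rust
use datafusion_expr::BuiltinScalarFunction;

fn main() {
    // `nanvl` is a new function with a single alias.
    let f: BuiltinScalarFunction = "nanvl".parse().unwrap();
    assert_eq!(f, BuiltinScalarFunction::Nanvl);

    // `array_element` resolves from any of its four aliases.
    for name in ["array_element", "array_extract", "list_element", "list_extract"] {
        assert_eq!(
            name.parse::<BuiltinScalarFunction>().unwrap(),
            BuiltinScalarFunction::ArrayElement
        );
    }
}
```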
BuiltinScalarFunction::ArrayLength => &["array_length", "list_length"], BuiltinScalarFunction::ArrayNdims => &["array_ndims", "list_ndims"], BuiltinScalarFunction::ArrayPosition => &[ @@ -1307,6 +1325,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "array_push_front", "list_push_front", ], + BuiltinScalarFunction::ArrayRepeat => &["array_repeat", "list_repeat"], BuiltinScalarFunction::ArrayRemove => &["array_remove", "list_remove"], BuiltinScalarFunction::ArrayRemoveN => &["array_remove_n", "list_remove_n"], BuiltinScalarFunction::ArrayRemoveAll => &["array_remove_all", "list_remove_all"], @@ -1315,6 +1334,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::ArrayReplaceAll => { &["array_replace_all", "list_replace_all"] } + BuiltinScalarFunction::ArraySlice => &["array_slice", "list_slice"], BuiltinScalarFunction::ArrayToString => &[ "array_to_string", "list_to_string", @@ -1323,7 +1343,9 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { ], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], - BuiltinScalarFunction::TrimArray => &["trim_array"], + + // struct functions + BuiltinScalarFunction::Struct => &["struct"], } } @@ -1340,9 +1362,7 @@ impl FromStr for BuiltinScalarFunction { if let Some(func) = NAME_TO_FUNCTION.get(name) { Ok(*func) } else { - Err(DataFusionError::Plan(format!( - "There is no built-in function named {name}" - ))) + plan_err!("There is no built-in function named {name}") } } } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index aba44061387a..c31bd04eafa0 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -19,7 +19,7 @@ use crate::expr::Case; use crate::{expr_schema::ExprSchemable, Expr}; use arrow::datatypes::DataType; -use datafusion_common::{DFSchema, DataFusionError, Result}; +use datafusion_common::{plan_err, DFSchema, DataFusionError, Result}; use std::collections::HashSet; /// Currently supported types by the coalesce function. @@ -102,9 +102,9 @@ impl CaseBuilder { } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { - return Err(DataFusionError::Plan(format!( + return plan_err!( "CASE expression 'then' values had multiple data types: {unique_types:?}" - ))); + ); } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 485b91141f7f..a0cfb6e1b00a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -115,7 +115,8 @@ pub enum Expr { IsNotUnknown(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key + /// Returns the field of a [`arrow::array::ListArray`] or + /// [`arrow::array::StructArray`] by index or range GetIndexedField(GetIndexedField), /// Whether an expression is between a given range. 
Between(Between), @@ -358,19 +359,32 @@ impl ScalarUDF { } } -/// Returns the field of a [`arrow::array::ListArray`] or [`arrow::array::StructArray`] by key +/// Access a sub field of a nested type, such as `Field` or `List` +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum GetFieldAccess { + /// Named field, for example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key: Box }, + /// List range, for example `list[i:j]` + ListRange { start: Box, stop: Box }, +} + +/// Returns the field of a [`arrow::array::ListArray`] or +/// [`arrow::array::StructArray`] by `key`. See [`GetFieldAccess`] for +/// details. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct GetIndexedField { - /// the expression to take the field from + /// The expression to take the field from pub expr: Box, /// The name of the field to take - pub key: ScalarValue, + pub field: GetFieldAccess, } impl GetIndexedField { /// Create a new GetIndexedField expression - pub fn new(expr: Box, key: ScalarValue) -> Self { - Self { expr, key } + pub fn new(expr: Box, field: GetFieldAccess) -> Self { + Self { expr, field } } } @@ -912,10 +926,94 @@ impl Expr { )) } + /// Return access to the named field. Example `expr["name"]` + /// + /// ## Access field "my_field" from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// { + /// "my_field": 123.34, + /// "other_field": "Boston", + /// } + /// ``` + /// + /// You can access column "my_field" with + /// + /// ``` + /// # use datafusion_expr::{col}; + /// let expr = col("c1") + /// .field("my_field"); + /// assert_eq!(expr.display_name().unwrap(), "c1[my_field]"); + /// ``` + pub fn field(self, name: impl Into) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::NamedStructField { + name: ScalarValue::Utf8(Some(name.into())), + }, + }) + } + + /// Return access to the element field. 
Example `expr[3]` + /// + /// ## Example: Access element 3 from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// [10, 20, 30, 40] + /// ``` + /// + /// You can access the value `30` with + /// + /// ``` + /// # use datafusion_expr::{lit, col, Expr}; + /// let expr = col("c1") + /// .index(lit(3)); + /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(3)]"); + /// ``` + pub fn index(self, key: Expr) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::ListIndex { key: Box::new(key) }, + }) + } + + /// Return elements between 1-based `start` and `stop`, for + /// example `expr[1:3]` + /// + /// ## Example: Access elements 2, 3, 4 from column "c1" + /// + /// For example if column "c1" holds documents like this + /// + /// ```json + /// [10, 20, 30, 40] + /// ``` + /// + /// You can access the value `[20, 30, 40]` with + /// + /// ``` + /// # use datafusion_expr::{lit, col}; + /// let expr = col("c1") + /// .range(lit(2), lit(4)); + /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4)]"); + /// ``` + pub fn range(self, start: Expr, stop: Expr) -> Self { + Expr::GetIndexedField(GetIndexedField { + expr: Box::new(self), + field: GetFieldAccess::ListRange { + start: Box::new(start), + stop: Box::new(stop), + }, + }) + } + pub fn try_into_col(&self) -> Result { match self { Expr::Column(it) => Ok(it.clone()), - _ => plan_err!(format!("Could not coerce '{self}' into Column!")), + _ => plan_err!("Could not coerce '{self}' into Column!"), } } @@ -1139,9 +1237,15 @@ impl fmt::Display for Expr { } Expr::Wildcard => write!(f, "*"), Expr::QualifiedWildcard { qualifier } => write!(f, "{qualifier}.*"), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - write!(f, "({expr})[{key}]") - } + Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { + GetFieldAccess::NamedStructField { name } => { + write!(f, "({expr})[{name}]") + } + GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"), + GetFieldAccess::ListRange { start, stop } => { + write!(f, "({expr})[{start}:{stop}]") + } + }, Expr::GroupingSet(grouping_sets) => match grouping_sets { GroupingSet::Rollup(exprs) => { // ROLLUP (c0, c1, c2) @@ -1330,9 +1434,22 @@ fn create_name(e: &Expr) -> Result { Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).name().clone()) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { let expr = create_name(expr)?; - Ok(format!("{expr}[{key}]")) + match field { + GetFieldAccess::NamedStructField { name } => { + Ok(format!("{expr}[{name}]")) + } + GetFieldAccess::ListIndex { key } => { + let key = create_name(key)?; + Ok(format!("{expr}[{key}]")) + } + GetFieldAccess::ListRange { start, stop } => { + let start = create_name(start)?; + let stop = create_name(stop)?; + Ok(format!("{expr}[{start}:{stop}]")) + } + } } Expr::ScalarFunction(func) => { create_function_name(&func.fun.to_string(), false, &func.args) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cb5317da4408..ef6ce8171153 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -136,6 +136,17 @@ pub fn sum(expr: Expr) -> Expr { )) } +/// Create an expression to represent the array_agg() aggregate function +pub fn array_agg(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::ArrayAgg, 
vec![expr], + false, + None, + None, + )) +} + /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( @@ -560,10 +571,10 @@ scalar_expr!( "returns an array of the array's dimensions." ); scalar_expr!( - ArrayFill, - array_fill, - element array, - "returns an array filled with copies of the given value." + ArrayElement, + array_element, + array element, + "extracts the element with the index n from the array." ); scalar_expr!( ArrayLength, @@ -595,6 +606,12 @@ scalar_expr!( array element, "prepends an element to the beginning of an array." ); +scalar_expr!( + ArrayRepeat, + array_repeat, + element count, + "returns an array containing element `count` times." +); scalar_expr!( ArrayRemove, array_remove, @@ -631,6 +648,12 @@ scalar_expr!( array from to, "replaces all occurrences of the specified element with another specified element." ); +scalar_expr!( + ArraySlice, + array_slice, + array offset length, + "returns a slice of the array." +); scalar_expr!( ArrayToString, array_to_string, @@ -648,12 +671,6 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); -scalar_expr!( - TrimArray, - trim_array, - array n, - "removes the last n elements from the array." -); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); @@ -787,6 +804,7 @@ scalar_expr!( scalar_expr!(CurrentDate, current_date, ,"returns current UTC date as a [`DataType::Date32`] value"); scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the same value for all instances of now() in same statement"); scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value"); +scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); @@ -978,6 +996,7 @@ mod test { test_unary_scalar_expr!(Log10, log10); test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); + test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); @@ -1043,12 +1062,12 @@ mod test { test_scalar_expr!(ArrayAppend, array_append, array, element); test_unary_scalar_expr!(ArrayDims, array_dims); - test_scalar_expr!(ArrayFill, array_fill, element, array); test_scalar_expr!(ArrayLength, array_length, array, dimension); test_unary_scalar_expr!(ArrayNdims, array_ndims); test_scalar_expr!(ArrayPosition, array_position, array, element, index); test_scalar_expr!(ArrayPositions, array_positions, array, element); test_scalar_expr!(ArrayPrepend, array_prepend, array, element); + test_scalar_expr!(ArrayRepeat, array_repeat, element, count); test_scalar_expr!(ArrayRemove, array_remove, array, element); test_scalar_expr!(ArrayRemoveN, array_remove_n, array, element, max); test_scalar_expr!(ArrayRemoveAll, array_remove_all, array, element); @@ -1058,7 +1077,6 @@ mod test { test_scalar_expr!(ArrayToString, array_to_string, array, delimiter); test_unary_scalar_expr!(Cardinality, cardinality); test_nary_scalar_expr!(MakeArray, array, input); - test_scalar_expr!(TrimArray, trim_array, array, n); test_unary_scalar_expr!(ArrowTypeof, arrow_typeof); } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 4a2673bcc97f..1d26485b4e03 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,15 +17,18 @@ use super::{Between, Expr, Like}; use 
crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetIndexedField, InList, - InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarUDF, Sort, + TryCast, WindowFunction, }; -use crate::field_util::get_indexed_field; +use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; use crate::{LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; -use arrow::datatypes::DataType; -use datafusion_common::{Column, DFField, DFSchema, DataFusionError, ExprSchema, Result}; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::{ + plan_err, Column, DFField, DFSchema, DataFusionError, ExprSchema, Result, +}; use std::collections::HashMap; use std::sync::Arc; @@ -153,10 +156,8 @@ impl ExprSchemable for Expr { // grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(schema)?; - - get_indexed_field(&data_type, key).map(|x| x.data_type().clone()) + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + field_for_index(expr, field, schema).map(|x| x.data_type().clone()) } } } @@ -264,9 +265,8 @@ impl ExprSchemable for Expr { "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), )), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let data_type = expr.get_type(input_schema)?; - get_indexed_field(&data_type, key).map(|x| x.is_nullable()) + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) } Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear @@ -330,11 +330,31 @@ impl ExprSchemable for Expr { _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), } } else { - Err(DataFusionError::Plan(format!( - "Cannot automatically convert {this_type:?} to {cast_to_type:?}" - ))) + plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}") + } + } +} + +/// return the schema [`Field`] for the type referenced by `get_indexed_field` +fn field_for_index( + expr: &Expr, + field: &GetFieldAccess, + schema: &S, +) -> Result { + let expr_dt = expr.get_type(schema)?; + match field { + GetFieldAccess::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } } + GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.get_type(schema)?, + }, + GetFieldAccess::ListRange { start, stop } => GetFieldAccessSchema::ListRange { + start_dt: start.get_type(schema)?, + stop_dt: stop.get_type(schema)?, + }, } + .get_accessed_field(&expr_dt) } /// cast subquery in InSubquery/ScalarSubquery to a given type. diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs index feb96928c120..23260ea9c270 100644 --- a/datafusion/expr/src/field_util.rs +++ b/datafusion/expr/src/field_util.rs @@ -18,36 +18,67 @@ //! 
Utility functions for complex field access use arrow::datatypes::{DataType, Field}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; -/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`] -/// # Error -/// Errors if -/// * the `data_type` is not a Struct or, -/// * there is no field key is not of the required index type -pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result { - match (data_type, key) { - (DataType::List(lt), ScalarValue::Int64(Some(i))) => { - Ok(Field::new(i.to_string(), lt.data_type().clone(), true)) - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - Err(DataFusionError::Plan( - "Struct based indexed access requires a non empty string".to_string(), - )) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) +/// Types of the field access expression of a nested type, such as `Field` or `List` +pub enum GetFieldAccessSchema { + /// Named field, For example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key_dt: DataType }, + /// List range, for example `list[i:j]` + ListRange { + start_dt: DataType, + stop_dt: DataType, + }, +} + +impl GetFieldAccessSchema { + /// Returns the schema [`Field`] from a [`DataType::List`] or + /// [`DataType::Struct`] indexed by this structure + /// + /// # Error + /// Errors if + /// * the `data_type` is not a Struct or a List, + /// * the `data_type` of the name/index/start-stop do not match a supported index type + pub fn get_accessed_field(&self, data_type: &DataType) -> Result { + match self { + Self::NamedStructField{ name } => { + match (data_type, name) { + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(DataFusionError::Plan(format!("Field {s} not found in struct"))).map(|f| f.as_ref().clone()) + } + } + (DataType::Struct(_), _) => plan_err!( + "Only utf8 strings are valid as an indexed field in a struct" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } + } + Self::ListIndex{ key_dt } => { + match (data_type, key_dt) { + (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), + (DataType::List(_), _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } + } + Self::ListRange{ start_dt, stop_dt } => { + match (data_type, start_dt, stop_dt) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), + (DataType::List(_), _, _) => plan_err!( + "Only ints are valid as an indexed field in a list" + ), + (other, _, _) => plan_err!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}"), + } } } - (DataType::Struct(_), _) => Err(DataFusionError::Plan( - "Only utf8 strings are valid as an indexed field in a struct".to_string(), - )), - (DataType::List(_), _) => Err(DataFusionError::Plan( - "Only ints are valid as an indexed field 
in a list".to_string(), - )), - (other, _) => Err(DataFusionError::Plan( - format!("The expression to get an indexed field is only valid for `List` or `Struct` types, got {other}") - )), } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 9eba7124b7c5..d35233bc39d2 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,7 +60,8 @@ pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::ColumnarValue; pub use expr::{ - Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, GroupingSet, Like, TryCast, + Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + Like, TryCast, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 026d9c40c8e4..f89be03f7937 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -40,6 +40,7 @@ use crate::{ Expr, ExprSchemable, TableSource, }; use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion_common::plan_err; use datafusion_common::UnnestOptions; use datafusion_common::{ display::ToStringifiedPlan, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, @@ -129,13 +130,11 @@ impl LogicalPlanBuilder { /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. pub fn values(mut values: Vec>) -> Result { if values.is_empty() { - return Err(DataFusionError::Plan("Values list cannot be empty".into())); + return plan_err!("Values list cannot be empty"); } let n_cols = values[0].len(); if n_cols == 0 { - return Err(DataFusionError::Plan( - "Values list cannot be zero length".into(), - )); + return plan_err!("Values list cannot be zero length"); } let empty_schema = DFSchema::empty(); let mut field_types: Vec> = Vec::with_capacity(n_cols); @@ -146,12 +145,12 @@ impl LogicalPlanBuilder { let mut nulls: Vec<(usize, usize)> = Vec::new(); for (i, row) in values.iter().enumerate() { if row.len() != n_cols { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Inconsistent data length across values list: got {} values in row {} but expected {}", row.len(), i, n_cols - ))); + ); } field_types = row .iter() @@ -164,8 +163,7 @@ impl LogicalPlanBuilder { let data_type = expr.get_type(&empty_schema)?; if let Some(prev_data_type) = &field_types[j] { if prev_data_type != &data_type { - let err = format!("Inconsistent data type across values list at row {i} column {j}"); - return Err(DataFusionError::Plan(err)); + return plan_err!("Inconsistent data type across values list at row {i} column {j}"); } } Ok(Some(data_type)) @@ -239,12 +237,20 @@ impl LogicalPlanBuilder { input: LogicalPlan, table_name: impl Into, table_schema: &Schema, + overwrite: bool, ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; + + let op = if overwrite { + WriteOp::InsertOverwrite + } else { + WriteOp::InsertInto + }; + Ok(Self::from(LogicalPlan::Dml(DmlStatement { table_name: table_name.into(), table_schema, - op: WriteOp::Insert, + op, input: Arc::new(input), }))) } @@ -259,9 +265,7 @@ impl LogicalPlanBuilder { let table_name = table_name.into(); if table_name.table().is_empty() { - return Err(DataFusionError::Plan( - "table_name cannot be empty".to_string(), - )); + return plan_err!("table_name cannot be empty"); } let schema = table_source.schema(); @@ -506,9 +510,7 @@ impl 
LogicalPlanBuilder { .map(|col| col.flat_name()) .collect::(); - Err(DataFusionError::Plan(format!( - "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list", - ))) + plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } /// Apply a sort @@ -634,9 +636,7 @@ impl LogicalPlanBuilder { null_equals_null: bool, ) -> Result { if join_keys.0.len() != join_keys.1.len() { - return Err(DataFusionError::Plan( - "left_keys and right_keys were not the same length".to_string(), - )); + return plan_err!("left_keys and right_keys were not the same length"); } let filter = if let Some(expr) = filter { @@ -932,9 +932,9 @@ impl LogicalPlanBuilder { let right_len = right_plan.schema().fields().len(); if left_len != right_len { - return Err(DataFusionError::Plan(format!( + return plan_err!( "INTERSECT/EXCEPT query must have the same number of columns. Left is {left_len} and right is {right_len}." - ))); + ); } let join_keys = left_plan @@ -980,9 +980,7 @@ impl LogicalPlanBuilder { filter: Option, ) -> Result { if equi_exprs.0.len() != equi_exprs.1.len() { - return Err(DataFusionError::Plan( - "left_keys and right_keys were not the same length".to_string(), - )); + return plan_err!("left_keys and right_keys were not the same length"); } let join_key_pairs = equi_exprs @@ -1138,12 +1136,10 @@ pub(crate) fn validate_unique_names<'a>( Ok(()) }, Some((existing_position, existing_expr)) => { - Err(DataFusionError::Plan( - format!("{node_name} require unique expression names \ + plan_err!("{node_name} require unique expression names \ but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \ - at position {position} have the same name. Consider aliasing (\"AS\") one of them.", + at position {position} have the same name. Consider aliasing (\"AS\") one of them." 
) - )) } } }) @@ -1182,9 +1178,8 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result Result>>()?; if inputs.is_empty() { - return Err(DataFusionError::Plan("Empty UNION".to_string())); + return plan_err!("Empty UNION"); } Ok(LogicalPlan::Union(Union { @@ -1771,9 +1766,7 @@ mod tests { assert_eq!("id", &name); Ok(()) } - _ => Err(DataFusionError::Plan( - "Plan should have returned an DataFusionError::SchemaError".to_string(), - )), + _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), } } @@ -1800,9 +1793,7 @@ mod tests { assert_eq!("state", &name); Ok(()) } - _ => Err(DataFusionError::Plan( - "Plan should have returned an DataFusionError::SchemaError".to_string(), - )), + _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), } } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 117a42cda970..07f34101eb3a 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -40,7 +40,8 @@ pub struct DmlStatement { #[derive(Clone, PartialEq, Eq, Hash)] pub enum WriteOp { - Insert, + InsertOverwrite, + InsertInto, Delete, Update, Ctas, @@ -49,7 +50,8 @@ pub enum WriteOp { impl Display for WriteOp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - WriteOp::Insert => write!(f, "Insert"), + WriteOp::InsertOverwrite => write!(f, "Insert Overwrite"), + WriteOp::InsertInto => write!(f, "Insert Into"), WriteOp::Delete => write!(f, "Delete"), WriteOp::Update => write!(f, "Update"), WriteOp::Ctas => write!(f, "Ctas"), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 67797764cd94..3557745ed346 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -528,23 +528,23 @@ impl LogicalPlan { LogicalPlan::Prepare(prepare_lp) => { // Verify if the number of params matches the number of values if prepare_lp.data_types.len() != param_values.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Expected {} parameters, got {}", prepare_lp.data_types.len(), param_values.len() - ))); + ); } // Verify if the types of the params matches the types of the values let iter = prepare_lp.data_types.iter().zip(param_values.iter()); for (i, (param_type, value)) in iter.enumerate() { if *param_type != value.get_datatype() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Expected parameter of type {:?}, got {:?} at index {}", param_type, value.get_datatype(), i - ))); + ); } } @@ -737,9 +737,7 @@ impl LogicalPlan { match (prev, data_type) { (Some(Some(prev)), Some(dt)) => { if prev != dt { - Err(DataFusionError::Plan(format!( - "Conflicting types for {id}" - )))?; + plan_err!("Conflicting types for {id}")?; } } (_, Some(dt)) => { @@ -768,9 +766,7 @@ impl LogicalPlan { match &expr { Expr::Placeholder(Placeholder { id, data_type }) => { if id.is_empty() || id == "$0" { - return Err(DataFusionError::Plan( - "Empty placeholder id".to_string(), - )); + return plan_err!("Empty placeholder id"); } // convert id (in format $1, $2, ..) to idx (0, 1, ..) 
let idx = id[1..].parse::().map_err(|e| { @@ -1300,7 +1296,7 @@ impl Projection { schema: DFSchemaRef, ) -> Result { if expr.len() != schema.fields().len() { - return Err(DataFusionError::Plan(format!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()))); + return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } // Update functional dependencies of `input` according to projection // expressions: @@ -1394,19 +1390,19 @@ impl Filter { // ignore errors resolving the expression against the schema. if let Ok(predicate_type) = predicate.get_type(input.schema()) { if predicate_type != DataType::Boolean { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" - ))); + ); } } // filter predicates should not be aliased if let Expr::Alias(Alias { expr, name, .. }) = predicate { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Attempted to create Filter predicate with \ expression `{expr}` aliased as '{name}'. Filter predicates should not be \ aliased." - ))); + ); } Ok(Self { predicate, input }) @@ -1651,18 +1647,17 @@ impl Aggregate { schema: DFSchemaRef, ) -> Result { if group_expr.is_empty() && aggr_expr.is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Aggregate requires at least one grouping or aggregate expression" - .to_string(), - )); + ); } let group_expr_count = grouping_set_expr_count(&group_expr)?; if schema.fields().len() != group_expr_count + aggr_expr.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Aggregate schema has wrong number of fields. Expected {} got {}", group_expr_count + aggr_expr.len(), schema.fields().len() - ))); + ); } let aggregate_func_dependencies = diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index c56228d40b0d..f74cc164a7a5 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -342,10 +342,10 @@ impl TreeNode for Expr { Expr::QualifiedWildcard { qualifier } => { Expr::QualifiedWildcard { qualifier } } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { Expr::GetIndexedField(GetIndexedField::new( transform_boxed(expr, &mut transform)?, - key, + field, )) } Expr::Placeholder(Placeholder { id, data_type }) => { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index dec2eb7f1238..2dc806e57c35 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -19,7 +19,7 @@ use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use std::ops::Deref; use crate::{AggregateFunction, Signature, TypeSignature}; @@ -106,10 +106,11 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. 
if !is_sum_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -117,10 +118,11 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval if !is_avg_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -130,10 +132,11 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. if !is_bit_and_or_xor_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -141,127 +144,131 @@ pub fn coerce_types( // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. if !is_bool_and_or_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Variance => { - if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::VariancePop => { + AggregateFunction::Variance | AggregateFunction::VariancePop => { if !is_variance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Covariance => { - if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::CovariancePop => { + AggregateFunction::Covariance | AggregateFunction::CovariancePop => { if !is_covariance_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::Stddev => { + AggregateFunction::Stddev | AggregateFunction::StddevPop => { if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::StddevPop => { - if !is_stddev_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + AggregateFunction::Correlation => { + if !is_correlation_support_arg_type(&input_types[0]) { + return 
plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } - AggregateFunction::Correlation => { - if !is_correlation_support_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY => { + let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat(); + let input_types_valid = // number of input already checked before + valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); + if !input_types_valid { + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The percentile sample points count for {:?} must be integer, not {:?}.", agg_fun, input_types[2] - ))); + ); } let mut result = input_types.to_vec(); if can_coerce_from(&DataType::Float64, &input_types[1]) { result[1] = DataType::Float64; } else { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Could not coerce the percent argument for {:?} to Float64. 
Was {:?}.", agg_fun, input_types[1] - ))); + ); } Ok(result) } AggregateFunction::ApproxPercentileContWithWeight => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The weight argument for {:?} does not support inputs of type {:?}.", - agg_fun, input_types[1] - ))); + agg_fun, + input_types[1] + ); } if !matches!(input_types[2], DataType::Float64) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, input_types[2] - ))); + agg_fun, + input_types[2] + ); } Ok(input_types.to_vec()) } AggregateFunction::ApproxMedian => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not support inputs of type {:?}.", - agg_fun, input_types[0] - ))); + agg_fun, + input_types[0] + ); } Ok(input_types.to_vec()) } @@ -285,22 +292,22 @@ fn check_arg_count( match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { if input_types.len() != *agg_count { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} expects {:?} arguments, but {:?} were provided", agg_fun, agg_count, input_types.len() - ))); + ); } } TypeSignature::Exact(types) => { if types.len() != input_types.len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} expects {:?} arguments, but {:?} were provided", agg_fun, types.len(), input_types.len() - ))); + ); } } TypeSignature::OneOf(variants) => { @@ -308,18 +315,18 @@ fn check_arg_count( .iter() .any(|v| check_arg_count(agg_fun, input_types, v).is_ok()); if !ok { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {:?} does not accept {:?} function arguments.", agg_fun, input_types.len() - ))); + ); } } TypeSignature::VariadicAny => { if input_types.is_empty() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function {agg_fun:?} expects at least one argument" - ))); + ); } } _ => { @@ -370,9 +377,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { DataType::Dictionary(_, dict_value_type) => { sum_return_type(dict_value_type.as_ref()) } - other => Err(DataFusionError::Plan(format!( - "SUM does not support type \"{other:?}\"" - ))), + other => plan_err!("SUM does not support type \"{other:?}\""), } } @@ -381,9 +386,7 @@ pub fn variance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "VAR does not support {arg_type:?}" - ))) + plan_err!("VAR does not support {arg_type:?}") } } @@ -392,9 +395,7 @@ pub fn covariance_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "COVAR does not support {arg_type:?}" - ))) + plan_err!("COVAR does not support {arg_type:?}") } } @@ -403,9 +404,7 @@ pub fn correlation_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "CORR does not support {arg_type:?}" - ))) + plan_err!("CORR does not support 
{arg_type:?}") } } @@ -414,9 +413,7 @@ pub fn stddev_return_type(arg_type: &DataType) -> Result { if NUMERICS.contains(arg_type) { Ok(DataType::Float64) } else { - Err(DataFusionError::Plan(format!( - "STDDEV does not support {arg_type:?}" - ))) + plan_err!("STDDEV does not support {arg_type:?}") } } @@ -441,9 +438,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { DataType::Dictionary(_, dict_value_type) => { avg_return_type(dict_value_type.as_ref()) } - other => Err(DataFusionError::Plan(format!( - "AVG does not support {other:?}" - ))), + other => plan_err!("AVG does not support {other:?}"), } } @@ -464,9 +459,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { DataType::Dictionary(_, dict_value_type) => { avg_sum_type(dict_value_type.as_ref()) } - other => Err(DataFusionError::Plan(format!( - "AVG does not support {other:?}" - ))), + other => plan_err!("AVG does not support {other:?}"), } } diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 0113408fcc8b..602448f1a2ff 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -23,8 +23,8 @@ use arrow::datatypes::{ DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::{plan_err, DataFusionError}; use crate::type_coercion::is_numeric; use crate::Operator; @@ -82,9 +82,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result | (DataType::Null, DataType::Null) | (DataType::Boolean, DataType::Null) | (DataType::Null, DataType::Boolean) => Ok(Signature::uniform(DataType::Boolean)), - _ => Err(DataFusionError::Plan(format!( + _ => plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" - ))), + ), }, Operator::RegexMatch | Operator::RegexIMatch | @@ -164,9 +164,9 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result // Numeric arithmetic, e.g. Int32 + Int32 Ok(Signature::uniform(numeric)) } else { - Err(DataFusionError::Plan(format!( + plan_err!( "Cannot coerce arithmetic expression {lhs} {op} {rhs} to valid types" - ))) + ) } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d86914325fc9..371a3950ac6e 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -20,7 +20,7 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; /// Performs type coercion for function arguments. 
/// @@ -52,10 +52,11 @@ pub fn data_types( } // none possible -> Error - Err(DataFusionError::Plan(format!( + plan_err!( "Coercion from {:?} to the signature {:?} failed.", - current_types, &signature.type_signature - ))) + current_types, + &signature.type_signature + ) } fn get_valid_types( @@ -84,11 +85,11 @@ fn get_valid_types( TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::Any(number) => { if current_types.len() != *number { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The function expected {} arguments but received {}", number, current_types.len() - ))); + ); } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 697aeb49a4da..76061194eddd 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{ RewriteRecursion, TreeNode, TreeNodeRewriter, VisitRecursion, }; use datafusion_common::{ - Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, - ScalarValue, TableReference, + plan_err, Column, Constraints, DFField, DFSchema, DFSchemaRef, DataFusionError, + Result, ScalarValue, TableReference, }; use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; use std::cmp::Ordering; @@ -60,10 +60,9 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { if group_expr.len() > 1 { - return Err(DataFusionError::Plan( + return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" - .to_string(), - )); + ); } Ok(grouping_set.distinct_expr().len()) } else { @@ -114,7 +113,7 @@ fn powerset(slice: &[T]) -> Result>, String> { fn check_grouping_set_size_limit(size: usize) -> Result<()> { let max_grouping_set_size = 65535; if size > max_grouping_set_size { - return Err(DataFusionError::Plan(format!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"))); + return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"); } Ok(()) @@ -124,7 +123,7 @@ fn check_grouping_set_size_limit(size: usize) -> Result<()> { fn check_grouping_sets_size_limit(size: usize) -> Result<()> { let max_grouping_sets_size = 4096; if size > max_grouping_sets_size { - return Err(DataFusionError::Plan(format!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"))); + return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"); } Ok(()) @@ -252,10 +251,9 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { if group_expr.len() > 1 { - return Err(DataFusionError::Plan( + return plan_err!( "Invalid group by expressions, GroupingSet must be the only expression" - .to_string(), - )); + ); } Ok(grouping_set.distinct_expr()) } else { @@ -340,9 +338,7 @@ fn get_excluded_columns( // if HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. 
if n_elem != unique_idents.len() { - return Err(DataFusionError::Plan( - "EXCLUDE or EXCEPT contains duplicate column names".to_string(), - )); + return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); } let mut result = vec![]; @@ -441,9 +437,7 @@ pub fn expand_qualified_wildcard( .cloned() .collect(); if qualified_fields.is_empty() { - return Err(DataFusionError::Plan(format!( - "Invalid qualifier {qualifier}" - ))); + return plan_err!("Invalid qualifier {qualifier}"); } let qualified_schema = DFSchema::new_with_metadata(qualified_fields, schema.metadata().clone())? @@ -480,9 +474,7 @@ pub fn generate_sort_key( Expr::Sort(Sort { expr, .. }) => { Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) } - _ => Err(DataFusionError::Plan( - "Order by only accepts sort expressions".to_string(), - )), + _ => plan_err!("Order by only accepts sort expressions"), }) .collect::>>()?; @@ -964,15 +956,11 @@ pub fn from_plan( // If this check cannot pass it means some optimizer pass is // trying to optimize Explain directly if expr.is_empty() { - return Err(DataFusionError::Plan( - "Invalid EXPLAIN command. Expression is empty".to_string(), - )); + return plan_err!("Invalid EXPLAIN command. Expression is empty"); } if inputs.is_empty() { - return Err(DataFusionError::Plan( - "Invalid EXPLAIN command. Inputs are empty".to_string(), - )); + return plan_err!("Invalid EXPLAIN command. Inputs are empty"); } Ok(plan.clone()) diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index b2d1882788be..5c9cfcae663c 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,7 +23,7 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; use std::convert::{From, TryFrom}; @@ -68,14 +68,14 @@ impl TryFrom for WindowFrame { if let WindowFrameBound::Following(val) = &start_bound { if val.is_null() { - plan_error( - "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING", + plan_err!( + "Invalid window frame: start bound cannot be UNBOUNDED FOLLOWING" )? } } else if let WindowFrameBound::Preceding(val) = &end_bound { if val.is_null() { - plan_error( - "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING", + plan_err!( + "Invalid window frame: end bound cannot be UNBOUNDED PRECEDING" )? 
} }; @@ -161,10 +161,10 @@ pub fn regularize(mut frame: WindowFrame, order_bys: usize) -> Result Result result } } - _ => plan_error( - "Invalid window frame: frame offsets must be non negative integers", + _ => plan_err!( + "Invalid window frame: frame offsets must be non negative integers" )?, }))) } -fn plan_error(err_message: &str) -> Result { - Err(DataFusionError::Plan(err_message.to_string())) -} - impl fmt::Display for WindowFrameBound { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 1c847a4541c5..89c59baa4c29 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -24,7 +24,7 @@ use crate::aggregate_function::AggregateFunction; use crate::type_coercion::functions::data_types; use crate::{AggregateUDF, Signature, TypeSignature, Volatility, WindowUDF}; use arrow::datatypes::DataType; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use std::sync::Arc; use std::{fmt, str::FromStr}; use strum_macros::EnumIter; @@ -145,11 +145,7 @@ impl FromStr for BuiltInWindowFunction { "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, - _ => { - return Err(DataFusionError::Plan(format!( - "There is no built-in window function named {name}" - ))) - } + _ => return plan_err!("There is no built-in window function named {name}"), }) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index ce798781c6e1..912ac069e0b6 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -75,7 +75,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .collect::>>()?; Ok(Transformed::Yes(LogicalPlan::Aggregate( - Aggregate::try_new(agg.input.clone(), agg.group_expr.clone(), aggr_expr)?, + Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } LogicalPlan::Sort(Sort { expr, input, fetch }) => { diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index c36a74c328bd..fee637711437 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -18,7 +18,7 @@ use crate::analyzer::check_plan; use crate::utils::{collect_subquery_cols, split_conjunction}; use datafusion_common::tree_node::{TreeNode, VisitRecursion}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, @@ -42,11 +42,11 @@ pub fn check_subquery_expr( if let Expr::ScalarSubquery(subquery) = expr { // Scalar subquery should only return one column if subquery.subquery.schema().fields().len() > 1 { - return Err(datafusion_common::DataFusionError::Plan(format!( + return plan_err!( "Scalar subquery should only return one column, but found {}: {}", subquery.subquery.schema().fields().len(), - subquery.subquery.schema().field_names().join(", "), - ))); + subquery.subquery.schema().field_names().join(", ") + ); } // Correlated scalar subquery must be aggregated to return at most one row if !subquery.outer_ref_columns.is_empty() { @@ -71,10 +71,9 @@ pub 
fn check_subquery_expr( { Ok(()) } else { - Err(DataFusionError::Plan( + plan_err!( "Correlated scalar subquery must be aggregated to return at most one row" - .to_string(), - )) + ) } } }?; @@ -84,18 +83,16 @@ pub fn check_subquery_expr( LogicalPlan::Aggregate(Aggregate {group_expr, aggr_expr,..}) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic - Err(DataFusionError::Plan( + plan_err!( "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" - .to_string(), - )) + ) } else { Ok(()) } }, - _ => Err(DataFusionError::Plan( + _ => plan_err!( "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - .to_string(), - )) + ) }?; } check_correlations_in_subquery(inner_plan, true) @@ -103,11 +100,11 @@ pub fn check_subquery_expr( if let Expr::InSubquery(subquery) = expr { // InSubquery should only return one column if subquery.subquery.subquery.schema().fields().len() > 1 { - return Err(datafusion_common::DataFusionError::Plan(format!( + return plan_err!( "InSubquery should only return one column, but found {}: {}", subquery.subquery.subquery.schema().fields().len(), - subquery.subquery.subquery.schema().field_names().join(", "), - ))); + subquery.subquery.subquery.schema().field_names().join(", ") + ); } } match outer_plan { @@ -116,11 +113,10 @@ pub fn check_subquery_expr( | LogicalPlan::Window(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), - _ => Err(DataFusionError::Plan( + _ => plan_err!( "In/Exist subquery can only be used in \ Projection, Filter, Window functions, Aggregate and Join plan nodes" - .to_string(), - )), + ), }?; check_correlations_in_subquery(inner_plan, false) } @@ -142,9 +138,7 @@ fn check_inner_plan( can_contain_outer_ref: bool, ) -> Result<()> { if !can_contain_outer_ref && contains_outer_reference(inner_plan) { - return Err(DataFusionError::Plan( - "Accessing outer reference columns is not allowed in the plan".to_string(), - )); + return plan_err!("Accessing outer reference columns is not allowed in the plan"); } // We want to support as many operators as possible inside the correlated subquery match inner_plan { @@ -166,9 +160,9 @@ fn check_inner_plan( .filter(|expr| !can_pullup_over_aggregation(expr)) .collect::>(); if is_aggregate && is_scalar && !maybe_unsupport.is_empty() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Correlated column is not allowed in predicate: {predicate}" - ))); + ); } check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) } @@ -231,9 +225,7 @@ fn check_inner_plan( Ok(()) } }, - _ => Err(DataFusionError::Plan( - "Unsupported operator in the subquery plan.".to_string(), - )), + _ => plan_err!("Unsupported operator in the subquery plan."), } } @@ -249,10 +241,9 @@ fn check_aggregation_in_scalar_subquery( agg: &Aggregate, ) -> Result<()> { if agg.aggr_expr.is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Correlated scalar subquery must be aggregated to return at most one row" - .to_string(), - )); + ); } if !agg.group_expr.is_empty() { let correlated_exprs = get_correlated_expressions(inner_plan)?; @@ -268,10 +259,9 @@ fn check_aggregation_in_scalar_subquery( if !group_columns.all(|group| inner_subquery_cols.contains(&group)) { // Group BY columns must be a subset of columns in the correlated expressions - return Err(DataFusionError::Plan( + return plan_err!( "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns" - 
.to_string(), - )); + ); } } Ok(()) @@ -341,10 +331,9 @@ fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty() }); if mixed { - Err(DataFusionError::Plan( + plan_err!( "Window expressions should not contain a mixed of outer references and inner columns" - .to_string(), - )) + ) } else { Ok(()) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 2bd7fd67d8c9..9d313e847433 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -23,7 +23,9 @@ use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; -use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, ScalarUDF, WindowFunction, @@ -293,9 +295,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); match result_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can not find compatible types to compare {expr_data_type:?} with {list_data_types:?}" - ))), + ), Some(coerced_type) => { // find the coerced type let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 8cd90072ffcc..b5cf73733896 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -22,7 +22,7 @@ use crate::utils::{ use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, }; -use datafusion_common::Result; +use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -132,9 +132,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { match (&pull_up_expr_opt, &self.pull_up_having_expr) { (Some(_), Some(_)) => { // Error path - Err(DataFusionError::Plan( - "Unsupported Subquery plan".to_string(), - )) + plan_err!("Unsupported Subquery plan") } (Some(_), None) => { self.pull_up_having_expr = pull_up_expr_opt; diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index d3500304af0e..432d7f053aef 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -21,7 +21,7 @@ use crate::utils::{conjunction, replace_qualified_name, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::TreeNode; -use datafusion_common::{Column, DataFusionError, Result}; +use datafusion_common::{plan_err, Column, DataFusionError, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; @@ -204,12 +204,13 @@ fn build_join( let in_predicate_opt = where_in_expr_opt .clone() .map(|where_in_expr| { - 
query_info.query.subquery.head_output_expr()?.map_or( - Err(DataFusionError::Plan( - "single expression required.".to_string(), - )), - |expr| Ok(Expr::eq(where_in_expr, expr)), - ) + query_info + .query + .subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |expr| { + Ok(Expr::eq(where_in_expr, expr)) + }) }) .map_or(Ok(None), |v| v.map(Some))?; diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 533566a0bf69..ec4d8a2cbf1d 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -20,7 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, @@ -155,9 +155,7 @@ fn flatten_join_inputs( vec![left, right] } _ => { - return Err(DataFusionError::Plan( - "flatten_join_inputs just can call join/cross_join".to_string(), - )); + return plan_err!("flatten_join_inputs just can call join/cross_join"); } }; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 920b9ea18f92..caac4c34bdd3 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -447,7 +447,9 @@ mod tests { use crate::optimizer::Optimizer; use crate::test::test_table_scan; use crate::{OptimizerConfig, OptimizerContext, OptimizerRule}; - use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result}; + use datafusion_common::{ + plan_err, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + }; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; use std::sync::{Arc, Mutex}; @@ -613,7 +615,7 @@ mod tests { _: &LogicalPlan, _: &dyn OptimizerConfig, ) -> Result> { - Err(DataFusionError::Plan("rule failed".to_string())) + plan_err!("rule failed") } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 01e16058ec32..4de7596b329c 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
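// Context for the recurring change in this patch: each hand-written
// `Err(DataFusionError::Plan(format!(...)))` is replaced by a call to the `plan_err!`
// macro that the next import line brings into scope. The macro's definition is not
// part of this diff; a minimal sketch of a macro with the assumed behaviour is:
macro_rules! plan_err {
    ($($args:expr),*) => {
        Err(DataFusionError::Plan(format!($($args),*)))
    };
}
// Under that assumption, a call site such as `plan_err!("table does not exist")`
// evaluates to the same `Result` value as the code it replaces, with less
// boilerplate at each error-return site.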
-use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; use std::sync::Arc; @@ -156,9 +156,7 @@ fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { }; Ok((left_empty, right_empty)) } - _ => Err(DataFusionError::Plan( - "plan just can have two child".to_string(), - )), + _ => plan_err!("plan just can have two child"), } } @@ -177,9 +175,7 @@ fn empty_child(plan: &LogicalPlan) -> Result> { } _ => Ok(None), }, - _ => Err(DataFusionError::Plan( - "plan just can have one child".to_string(), - )), + _ => plan_err!("plan just can have one child"), } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index eb9ae3c981d9..0469f678e09e 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -26,7 +26,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::error::Result as ArrowResult; use datafusion_common::ScalarValue::UInt8; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, + plan_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, }; use datafusion_expr::expr::{AggregateFunction, Alias}; use datafusion_expr::utils::exprlist_to_fields; @@ -443,9 +443,7 @@ fn get_expr(columns: &HashSet, schema: &DFSchemaRef) -> Result }) .collect::>(); if columns.len() != expr.len() { - Err(DataFusionError::Plan(format!( - "required columns can't push down, columns: {columns:?}" - ))) + plan_err!("required columns can't push down, columns: {columns:?}") } else { Ok(expr) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 357a2775cc0c..96d2f45d808e 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -23,7 +23,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, }; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -215,12 +215,10 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { let subqry_alias = self.alias_gen.next("__scalar_sq"); self.sub_query_info .push((subquery.clone(), subqry_alias.clone())); - let scalar_expr = subquery.subquery.head_output_expr()?.map_or( - Err(DataFusionError::Plan( - "single expression required.".to_string(), - )), - Ok, - )?; + let scalar_expr = subquery + .subquery + .head_output_expr()? 
+ .map_or(plan_err!("single expression required."), Ok)?; Ok(Expr::Column(create_col_from_scalar_expr( &scalar_expr, subqry_alias, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5562e10f693b..b7e8612d538e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2499,10 +2499,43 @@ mod tests { col("c1") .in_list(vec![lit("foo"), lit("bar"), lit("baz"), lit("qux")], false), ); + assert_change( + regex_match(col("c1"), lit("^(fo_o)$")), + col("c1").eq(lit("fo_o")), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o)$")), + col("c1").eq(lit("fo_o")), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r)$")), + col("c1").eq(lit("fo_o")).or(col("c1").eq(lit("ba_r"))), + ); + assert_change( + regex_not_match(col("c1"), lit("^(fo_o|ba_r)$")), + col("c1") + .not_eq(lit("fo_o")) + .and(col("c1").not_eq(lit("ba_r"))), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r|ba_z)$")), + ((col("c1").eq(lit("fo_o"))).or(col("c1").eq(lit("ba_r")))) + .or(col("c1").eq(lit("ba_z"))), + ); + assert_change( + regex_match(col("c1"), lit("^(fo_o|ba_r|baz|qu_x)$")), + col("c1").in_list( + vec![lit("fo_o"), lit("ba_r"), lit("baz"), lit("qu_x")], + false, + ), + ); // regular expressions that mismatch captured literals assert_no_change(regex_match(col("c1"), lit("(foo|bar)"))); assert_no_change(regex_match(col("c1"), lit("(foo|bar)*"))); + assert_no_change(regex_match(col("c1"), lit("(fo_o|b_ar)"))); + assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*"))); + assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*"))); assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*"))); assert_no_change(regex_match(col("c1"), lit("^foo|bar$"))); assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$"))); diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 27fcfc5dbf47..5094623b82c0 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -108,7 +108,7 @@ fn collect_concat_to_like_string(parts: &[Hir]) -> Option { for sub in parts { if let HirKind::Literal(l) = sub.kind() { - s.push_str(str_from_literal(l)?); + s.push_str(like_str_from_literal(l)?); } else { return None; } @@ -120,7 +120,7 @@ fn collect_concat_to_like_string(parts: &[Hir]) -> Option { /// returns a str represented by `Literal` if it contains a valid utf8 /// sequence and is safe for like (has no '%' and '_') -fn str_from_literal(l: &Literal) -> Option<&str> { +fn like_str_from_literal(l: &Literal) -> Option<&str> { // if not utf8, no good let s = std::str::from_utf8(&l.0).ok()?; @@ -131,6 +131,14 @@ fn str_from_literal(l: &Literal) -> Option<&str> { } } +/// returns a str represented by `Literal` if it contains a valid utf8 +fn str_from_literal(l: &Literal) -> Option<&str> { + // if not utf8, no good + let s = std::str::from_utf8(&l.0).ok()?; + + Some(s) +} + fn is_safe_for_like(c: char) -> bool { (c != '%') && (c != '_') } @@ -196,7 +204,7 @@ fn anchored_literal_to_expr(v: &[Hir]) -> Option { 2 => Some(lit("")), 3 => { let HirKind::Literal(l) = v[1].kind() else { return None }; - str_from_literal(l).map(lit) + like_str_from_literal(l).map(lit) } _ => None, } @@ -242,7 +250,7 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { return 
Some(mode.expr(Box::new(left.clone()), "%".to_owned())); } HirKind::Literal(l) => { - let s = str_from_literal(l)?; + let s = like_str_from_literal(l)?; return Some(mode.expr(Box::new(left.clone()), format!("%{s}%"))); } HirKind::Concat(inner) if is_anchored_literal(inner) => { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index adb3bf6302fc..a3e7e42875d7 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,6 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::DataFusionError; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::{Alias, BinaryExpr}; diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 8027e6882ecc..142ae870d4b8 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -18,7 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; @@ -392,7 +392,7 @@ impl ContextProvider for MySchemaProvider { schema: Arc::new(schema), })) } else { - Err(DataFusionError::Plan("table does not exist".to_string())) + plan_err!("table does not exist") } } diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index f0a44cc97a66..46825a5fc049 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -27,6 +27,7 @@ use arrow::{ }, datatypes::{DataType, Field}, }; +use datafusion_common::plan_err; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_common::{downcast_value, ScalarValue}; @@ -152,9 +153,9 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { // Ensure the percentile is between 0 and 1. if !(0.0..=1.0).contains(&percentile) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ))); + ); } Ok(percentile) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index f0eb442af012..d91f06fc76e5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -18,9 +18,16 @@ //! Defines physical expressions which specify ordering requirement //! 
that can evaluated at runtime during query execution +use std::any::Any; +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::fmt::Debug; +use std::sync::Arc; + use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; use crate::expressions::format_state_name; use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; + use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use arrow_array::{Array, ListArray}; @@ -28,12 +35,8 @@ use arrow_schema::{Fields, SortOptions}; use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; + use itertools::izip; -use std::any::Any; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use std::fmt::Debug; -use std::sync::Arc; /// Expression for a ARRAY_AGG(ORDER BY) aggregation. /// When aggregation works in multiple partitions @@ -100,14 +103,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg { let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new( - "item", - DataType::Struct(Fields::from(orderings.clone())), - true, - ), + Field::new("item", DataType::Struct(Fields::from(orderings)), true), false, )); - fields.extend(orderings); Ok(fields) } @@ -207,6 +205,10 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { let mut partition_values = vec![]; // Stores ordering requirement expression results coming from each partition let mut partition_ordering_values = vec![]; + + // Existing values should be merged also. + partition_values.push(self.values.clone()); + partition_ordering_values.push(self.ordering_values.clone()); for index in 0..agg_orderings.len() { let ordering = ScalarValue::try_from_array(agg_orderings, index)?; // Ordering requirement expression values for each entry in the ARRAY_AGG list @@ -228,11 +230,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { .iter() .map(|sort_expr| sort_expr.options) .collect::>(); - self.values = merge_ordered_arrays( + let (new_values, new_orderings) = merge_ordered_arrays( &partition_values, &partition_ordering_values, &sort_options, )?; + self.values = new_values; + self.ordering_values = new_orderings; } else { return Err(DataFusionError::Execution( "Expects to receive a list array".to_string(), @@ -244,17 +248,6 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { fn state(&self) -> Result> { let mut result = vec![self.evaluate()?]; result.push(self.evaluate_orderings()?); - let last_ordering = if let Some(ordering) = self.ordering_values.last() { - ordering.clone() - } else { - // In case ordering is empty, construct ordering as NULL: - self.datatypes - .iter() - .skip(1) - .map(ScalarValue::try_from) - .collect::>>()? - }; - result.extend(last_ordering); Ok(result) } @@ -426,7 +419,7 @@ fn merge_ordered_arrays( ordering_values: &[Vec>], // Defines according to which ordering comparisons should be done. sort_options: &[SortOptions], -) -> Result> { +) -> Result<(Vec, Vec>)> { // Keep track the most recent data of each branch, in binary heap data structure. let mut heap: BinaryHeap = BinaryHeap::new(); @@ -449,6 +442,7 @@ fn merge_ordered_arrays( .map(|idx| values[idx].len()) .collect::>(); let mut merged_values = vec![]; + let mut merged_orderings = vec![]; // Continue iterating the loop until consuming data of all branches. 
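// Illustrative example of what this merge produces (hypothetical inputs): merging
// partition values [0, 1, 2] with ordering keys [0, 2, 4] against partition values
// [10, 11, 12] with ordering keys [1, 3, 5], under ascending sort options, yields
// merged values [0, 10, 1, 11, 2, 12] and, with this change, the matching merged
// ordering keys [0, 1, 2, 3, 4, 5] as the second element of the returned tuple.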
loop { let min_elem = if let Some(min_elem) = heap.pop() { @@ -490,6 +484,7 @@ fn merge_ordered_arrays( indices[branch_idx] += 1; let row_idx = indices[branch_idx]; merged_values.push(min_elem.value.clone()); + merged_orderings.push(min_elem.ordering.clone()); if row_idx < end_indices[branch_idx] { // Push next entry in the most recently consumed branch to the heap // If there is an available entry @@ -500,7 +495,7 @@ fn merge_ordered_arrays( } } - Ok(merged_values) + Ok((merged_values, merged_orderings)) } #[cfg(test)] @@ -553,14 +548,28 @@ mod tests { .collect::>>()?; let expected = Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef; + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2])) as ArrayRef, + Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4])) as ArrayRef, + ]; - let merged_vals = merge_ordered_arrays( + let (merged_vals, merged_ts) = merge_ordered_arrays( &[lhs_vals, rhs_vals], &[lhs_orderings, rhs_orderings], &sort_options, )?; let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; + assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); Ok(()) } @@ -607,15 +616,27 @@ mod tests { .collect::>>()?; let expected = Arc::new(Int64Array::from(vec![0, 0, 1, 1, 2, 2, 1, 1, 2, 2])) as ArrayRef; - - let merged_vals = merge_ordered_arrays( + let expected_ts = vec![ + Arc::new(Int64Array::from(vec![2, 2, 1, 1, 1, 1, 0, 0, 0, 0])) as ArrayRef, + Arc::new(Int64Array::from(vec![4, 4, 3, 3, 2, 2, 1, 1, 0, 0])) as ArrayRef, + ]; + let (merged_vals, merged_ts) = merge_ordered_arrays( &[lhs_vals, rhs_vals], &[lhs_orderings, rhs_orderings], &sort_options, )?; let merged_vals = ScalarValue::iter_to_array(merged_vals.into_iter())?; + let merged_ts = (0..merged_ts[0].len()) + .map(|col_idx| { + ScalarValue::iter_to_array( + (0..merged_ts.len()) + .map(|row_idx| merged_ts[row_idx][col_idx].clone()), + ) + }) + .collect::>>()?; assert_eq!(&merged_vals, &expected); + assert_eq!(&merged_ts, &expected_ts); Ok(()) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 4dc7c824edd0..bbccb6502665 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -26,6 +26,7 @@ //! * Signature: see `Signature` //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. 
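// Note on the regr_* aggregates registered below (illustrative): they follow the
// same convention, taking two numeric arguments and always returning Float64,
// e.g. regr_slope: ([f64, f64]) -> f64.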
+use crate::aggregate::regr::RegrType; use crate::{expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use arrow::datatypes::Schema; use datafusion_common::{DataFusionError, Result}; @@ -248,6 +249,86 @@ pub fn create_aggregate_expr( "CORR(DISTINCT) aggregations are not available".to_string(), )); } + (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Slope, + rt_type, + )), + (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Intercept, + rt_type, + )), + (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::Count, + rt_type, + )), + (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::R2, + rt_type, + )), + (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::AvgX, + rt_type, + )), + (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::AvgY, + rt_type, + )), + (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SXX, + rt_type, + )), + (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SYY, + rt_type, + )), + (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + RegrType::SXY, + rt_type, + )), + ( + AggregateFunction::RegrSlope + | AggregateFunction::RegrIntercept + | AggregateFunction::RegrCount + | AggregateFunction::RegrR2 + | AggregateFunction::RegrAvgx + | AggregateFunction::RegrAvgy + | AggregateFunction::RegrSXX + | AggregateFunction::RegrSYY + | AggregateFunction::RegrSXY, + true, + ) => { + return Err(DataFusionError::NotImplemented(format!( + "{}(DISTINCT) aggregations are not available", + fun + ))); + } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( @@ -333,6 +414,7 @@ mod tests { DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, }; use arrow::datatypes::{DataType, Field}; + use datafusion_common::plan_err; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::{type_coercion, Signature}; @@ -1213,9 +1295,9 @@ mod tests { let coerced_phy_exprs = coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; if coerced_phy_exprs.is_empty() { - return Err(DataFusionError::Plan(format!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'", - ))); + return plan_err!( + "Invalid or wrong number of arguments passed to aggregate: '{name}'" + ); } create_aggregate_expr(fun, distinct, &coerced_phy_exprs, &[], input_schema, name) } diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 656f30a13504..7e8930ce2a32 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs 
@@ -17,21 +17,23 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. +use std::any::Any; +use std::sync::Arc; + use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; use crate::expressions::format_state_name; use crate::{AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; - use arrow::compute; +use arrow::compute::{lexsort_to_indices, SortColumn}; +use arrow::datatypes::{DataType, Field}; use arrow_array::cast::AsArray; -use datafusion_common::utils::get_row_at_idx; -use std::any::Any; -use std::sync::Arc; +use arrow_array::{Array, BooleanArray}; +use arrow_schema::SortOptions; +use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; /// FIRST_VALUE aggregate expression #[derive(Debug)] @@ -76,6 +78,7 @@ impl AggregateExpr for FirstValue { Ok(Box::new(FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, + self.ordering_req.clone(), )?)) } @@ -132,6 +135,7 @@ impl AggregateExpr for FirstValue { Ok(Box::new(FirstValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, + self.ordering_req.clone(), )?)) } } @@ -159,11 +163,17 @@ struct FirstValueAccumulator { // Stores ordering values, of the aggregator requirement corresponding to first value // of the aggregator. These values are used during merging of multiple partitions. orderings: Vec, + // Stores the applicable ordering requirement. + ordering_req: LexOrdering, } impl FirstValueAccumulator { /// Creates a new `FirstValueAccumulator` for the given `data_type`. - pub fn try_new(data_type: &DataType, ordering_dtypes: &[DataType]) -> Result { + pub fn try_new( + data_type: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) @@ -172,8 +182,16 @@ impl FirstValueAccumulator { first: value, is_set: false, orderings, + ordering_req, }) } + + // Updates state with the values in the given row. + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.first = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; + } } impl Accumulator for FirstValueAccumulator { @@ -188,24 +206,41 @@ impl Accumulator for FirstValueAccumulator { // If we have seen first value, we shouldn't update it if !values[0].is_empty() && !self.is_set { let row = get_row_at_idx(values, 0)?; - // Update with last value in the array. - self.first = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + // Update with first value in the array. + self.update_with_new_row(&row); } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // FIRST_VALUE(first1, first2, first3, ...) - let last_idx = states.len() - 1; - let is_set_flags = &states[last_idx]; - let flags = is_set_flags.as_boolean(); - let mut filtered_first_vals = vec![]; - for state in states.iter().take(last_idx) { - filtered_first_vals.push(compute::filter(state, flags)?) + // last index contains is_set flag. 
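// Assumed intermediate-state layout behind this merge (per the comments in this
// function): states = [ first_value, ordering_col_1, ..., ordering_col_n, is_set ],
// so index 0 holds the value, 1..is_set_idx hold the ordering columns used for the
// lexicographic sort below, and the last column is the boolean is_set flag used to
// filter out entries coming from partitions that never saw any input.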
+ let is_set_idx = states.len() - 1; + let flags = states[is_set_idx].as_boolean(); + let filtered_states = filter_states_according_to_is_set(states, flags)?; + // 1..is_set_idx range corresponds to ordering section + let sort_cols = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + + let ordered_states = if sort_cols.is_empty() { + // When no ordering is given, use the existing state as is: + filtered_states + } else { + let indices = lexsort_to_indices(&sort_cols, None)?; + get_arrayref_at_indices(&filtered_states, &indices)? + }; + if !ordered_states[0].is_empty() { + let first_row = get_row_at_idx(&ordered_states, 0)?; + let first_ordering = &first_row[1..]; + let sort_options = get_sort_options(&self.ordering_req); + // Either there is no existing value, or there is an earlier version in new data. + if !self.is_set + || compare_rows(first_ordering, &self.orderings, &sort_options)?.is_lt() + { + self.update_with_new_row(&first_row); + } } - self.update_batch(&filtered_first_vals) + Ok(()) } fn evaluate(&self) -> Result { @@ -263,6 +298,7 @@ impl AggregateExpr for LastValue { Ok(Box::new(LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, + self.ordering_req.clone(), )?)) } @@ -319,6 +355,7 @@ impl AggregateExpr for LastValue { Ok(Box::new(LastValueAccumulator::try_new( &self.input_data_type, &self.order_by_data_types, + self.ordering_req.clone(), )?)) } } @@ -345,11 +382,17 @@ struct LastValueAccumulator { // occur due to empty partitions. is_set: bool, orderings: Vec, + // Stores the applicable ordering requirement. + ordering_req: LexOrdering, } impl LastValueAccumulator { /// Creates a new `LastValueAccumulator` for the given `data_type`. - pub fn try_new(data_type: &DataType, ordering_dtypes: &[DataType]) -> Result { + pub fn try_new( + data_type: &DataType, + ordering_dtypes: &[DataType], + ordering_req: LexOrdering, + ) -> Result { let orderings = ordering_dtypes .iter() .map(ScalarValue::try_from) @@ -358,8 +401,16 @@ impl LastValueAccumulator { last: ScalarValue::try_from(data_type)?, is_set: false, orderings, + ordering_req, }) } + + // Updates state with the values in the given row. + fn update_with_new_row(&mut self, row: &[ScalarValue]) { + self.last = row[0].clone(); + self.orderings = row[1..].to_vec(); + self.is_set = true; + } } impl Accumulator for LastValueAccumulator { @@ -374,23 +425,43 @@ impl Accumulator for LastValueAccumulator { if !values[0].is_empty() { let row = get_row_at_idx(values, values[0].len() - 1)?; // Update with last value in the array. - self.last = row[0].clone(); - self.orderings = row[1..].to_vec(); - self.is_set = true; + self.update_with_new_row(&row); } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // LAST_VALUE(last1, last2, last3, ...) - let last_idx = states.len() - 1; - let is_set_flags = &states[last_idx]; - let flags = is_set_flags.as_boolean(); - let mut filtered_first_vals = vec![]; - for state in states.iter().take(last_idx) { - filtered_first_vals.push(compute::filter(state, flags)?) + // last index contains is_set flag. 
+ let is_set_idx = states.len() - 1; + let flags = states[is_set_idx].as_boolean(); + let filtered_states = filter_states_according_to_is_set(states, flags)?; + // 1..is_set_idx range corresponds to ordering section + let sort_cols = + convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + + let ordered_states = if sort_cols.is_empty() { + // When no ordering is given, use existing state as is: + filtered_states + } else { + let indices = lexsort_to_indices(&sort_cols, None)?; + get_arrayref_at_indices(&filtered_states, &indices)? + }; + + if !ordered_states[0].is_empty() { + let last_idx = ordered_states[0].len() - 1; + let last_row = get_row_at_idx(&ordered_states, last_idx)?; + let last_ordering = &last_row[1..]; + let sort_options = get_sort_options(&self.ordering_req); + // Either there is no existing value, or there is a newer (latest) + // version in the new data: + if !self.is_set + || compare_rows(last_ordering, &self.orderings, &sort_options)?.is_gt() + { + self.update_with_new_row(&last_row); + } } - self.update_batch(&filtered_first_vals) + Ok(()) } fn evaluate(&self) -> Result { @@ -405,20 +476,57 @@ impl Accumulator for LastValueAccumulator { } } +/// Filters states according to the `is_set` flag at the last column and returns +/// the resulting states. +fn filter_states_according_to_is_set( + states: &[ArrayRef], + flags: &BooleanArray, +) -> Result> { + states + .iter() + .map(|state| compute::filter(state, flags).map_err(DataFusionError::ArrowError)) + .collect::>>() +} + +/// Combines array refs and their corresponding orderings to construct `SortColumn`s. +fn convert_to_sort_cols( + arrs: &[ArrayRef], + sort_exprs: &[PhysicalSortExpr], +) -> Vec { + arrs.iter() + .zip(sort_exprs.iter()) + .map(|(item, sort_expr)| SortColumn { + values: item.clone(), + options: Some(sort_expr.options), + }) + .collect::>() +} + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. 
+fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { + ordering_req + .iter() + .map(|item| item.options) + .collect::>() +} + #[cfg(test)] mod tests { use crate::aggregate::first_last::{FirstValueAccumulator, LastValueAccumulator}; + use arrow_array::{ArrayRef, Int64Array}; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; + use std::sync::Arc; #[test] fn test_first_last_value_value() -> Result<()> { let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[])?; - let mut last_accumulator = LastValueAccumulator::try_new(&DataType::Int64, &[])?; + FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + let mut last_accumulator = + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 5490b875763a..69918cfac268 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -49,6 +49,7 @@ pub mod build_in; pub(crate) mod groups_accumulator; mod hyperloglog; pub mod moving_min_max; +pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod sum; diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/physical-expr/src/aggregate/regr.rs new file mode 100644 index 000000000000..1b8a5c6f76de --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/regr.rs @@ -0,0 +1,460 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::Float64Array; +use arrow::{ + array::{ArrayRef, UInt64Array}, + compute::cast, + datatypes::DataType, + datatypes::Field, +}; +use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::Accumulator; + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; + +#[derive(Debug)] +pub struct Regr { + name: String, + regr_type: RegrType, + expr_y: Arc, + expr_x: Arc, +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +pub enum RegrType { + /// Variant for `regr_slope` aggregate expression + /// Returns the slope of the linear regression line for non-null pairs in aggregate columns. 
+ /// Given input column Y and X: `regr_slope(Y, X)` returns the slope (k in Y = k*X + b) using minimal + /// RSS (Residual Sum of Squares) fitting. + Slope, + /// Variant for `regr_intercept` aggregate expression + /// Returns the intercept of the linear regression line for non-null pairs in aggregate columns. + /// Given input column Y and X: `regr_intercept(Y, X)` returns the intercept (b in Y = k*X + b) using minimal + /// RSS fitting. + Intercept, + /// Variant for `regr_count` aggregate expression + /// Returns the number of input rows for which both expressions are not null. + /// Given input column Y and X: `regr_count(Y, X)` returns the count of non-null pairs. + Count, + /// Variant for `regr_r2` aggregate expression + /// Returns the coefficient of determination (R-squared value) of the linear regression line for non-null pairs in aggregate columns. + /// The R-squared value represents the proportion of variance in Y that is predictable from X. + R2, + /// Variant for `regr_avgx` aggregate expression + /// Returns the average of the independent variable for non-null pairs in aggregate columns. + /// Given input column X: `regr_avgx(Y, X)` returns the average of X values. + AvgX, + /// Variant for `regr_avgy` aggregate expression + /// Returns the average of the dependent variable for non-null pairs in aggregate columns. + /// Given input column Y: `regr_avgy(Y, X)` returns the average of Y values. + AvgY, + /// Variant for `regr_sxx` aggregate expression + /// Returns the sum of squares of the independent variable for non-null pairs in aggregate columns. + /// Given input column X: `regr_sxx(Y, X)` returns the sum of squares of deviations of X from its mean. + SXX, + /// Variant for `regr_syy` aggregate expression + /// Returns the sum of squares of the dependent variable for non-null pairs in aggregate columns. + /// Given input column Y: `regr_syy(Y, X)` returns the sum of squares of deviations of Y from its mean. + SYY, + /// Variant for `regr_sxy` aggregate expression + /// Returns the sum of products of pairs of numbers for non-null pairs in aggregate columns. + /// Given input column Y and X: `regr_sxy(Y, X)` returns the sum of products of the deviations of Y and X from their respective means. + SXY, +} + +impl Regr { + pub fn new( + expr_y: Arc, + expr_x: Arc, + name: impl Into, + regr_type: RegrType, + return_type: DataType, + ) -> Self { + // the result of regr_slope only support FLOAT64 data type. 
+ assert!(matches!(return_type, DataType::Float64)); + Self { + name: name.into(), + regr_type, + expr_y, + expr_x, + } + } +} + +impl AggregateExpr for Regr { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Float64, true)) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) + } + + fn create_sliding_accumulator(&self) -> Result> { + Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(&self.name, "mean_x"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean_y"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_x"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_y"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr_y.clone(), self.expr_x.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for Regr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.expr_y.eq(&x.expr_y) + && self.expr_x.eq(&x.expr_x) + }) + .unwrap_or(false) + } +} + +/// `RegrAccumulator` is used to compute linear regression aggregate functions +/// by maintaining statistics needed to compute them in an online fashion. +/// +/// This struct uses Welford's online algorithm for calculating variance and covariance: +/// +/// +/// Given the statistics, the following aggregate functions can be calculated: +/// +/// - `regr_slope(y, x)`: Slope of the linear regression line, calculated as: +/// cov_pop(x, y) / var_pop(x). +/// It represents the expected change in Y for a one-unit change in X. +/// +/// - `regr_intercept(y, x)`: Intercept of the linear regression line, calculated as: +/// mean_y - (regr_slope(y, x) * mean_x). +/// It represents the expected value of Y when X is 0. +/// +/// - `regr_count(y, x)`: Count of the non-null(both x and y) input rows. +/// +/// - `regr_r2(y, x)`: R-squared value (coefficient of determination), calculated as: +/// (cov_pop(x, y) ^ 2) / (var_pop(x) * var_pop(y)). +/// It provides a measure of how well the model's predictions match the observed data. +/// +/// - `regr_avgx(y, x)`: Average of the independent variable X, calculated as: mean_x. +/// +/// - `regr_avgy(y, x)`: Average of the dependent variable Y, calculated as: mean_y. +/// +/// - `regr_sxx(y, x)`: Sum of squares of the independent variable X, calculated as: +/// m2_x. +/// +/// - `regr_syy(y, x)`: Sum of squares of the dependent variable Y, calculated as: +/// m2_y. +/// +/// - `regr_sxy(y, x)`: Sum of products of paired values, calculated as: +/// algo_const. +/// +/// Here's how the statistics maintained in this struct are calculated: +/// - `cov_pop(x, y)`: algo_const / count. +/// - `var_pop(x)`: m2_x / count. +/// - `var_pop(y)`: m2_y / count. 
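// Worked example of the formulas documented above (illustrative numbers): feeding
// the pairs (x, y) = (1, 2), (2, 4), (3, 6) through the update rule gives
// count = 3, mean_x = 2, mean_y = 4, m2_x = 2, m2_y = 8 and algo_const = 4, hence
// cov_pop(x, y) = 4/3 and var_pop(x) = 2/3, so regr_slope = (4/3) / (2/3) = 2,
// regr_intercept = 4 - 2 * 2 = 0 and regr_r2 = 1, i.e. exactly the line y = 2 * x.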
+#[derive(Debug)] +pub struct RegrAccumulator { + count: u64, + mean_x: f64, + mean_y: f64, + m2_x: f64, + m2_y: f64, + algo_const: f64, + regr_type: RegrType, +} + +impl RegrAccumulator { + /// Creates a new `RegrAccumulator` + pub fn try_new(regr_type: &RegrType) -> Result { + Ok(Self { + count: 0_u64, + mean_x: 0_f64, + mean_y: 0_f64, + m2_x: 0_f64, + m2_y: 0_f64, + algo_const: 0_f64, + regr_type: regr_type.clone(), + }) + } +} + +impl Accumulator for RegrAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean_x), + ScalarValue::from(self.mean_y), + ScalarValue::from(self.m2_x), + ScalarValue::from(self.m2_y), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // regr_slope(Y, X) calculates k in y = k*x + b + let values_y = &cast(&values[0], &DataType::Float64)?; + let values_x = &cast(&values[1], &DataType::Float64)?; + + let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); + let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + + for i in 0..values_y.len() { + // skip either x or y is NULL + let value_y = if values_y.is_valid(i) { + arr_y.next() + } else { + None + }; + let value_x = if values_x.is_valid(i) { + arr_x.next() + } else { + None + }; + if value_y.is_none() || value_x.is_none() { + continue; + } + + // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] + let value_y = unwrap_or_internal_err!(value_y); + let value_x = unwrap_or_internal_err!(value_x); + + self.count += 1; + let delta_x = value_x - self.mean_x; + let delta_y = value_y - self.mean_y; + self.mean_x += delta_x / self.count as f64; + self.mean_y += delta_y / self.count as f64; + let delta_x_2 = value_x - self.mean_x; + let delta_y_2 = value_y - self.mean_y; + self.m2_x += delta_x * delta_x_2; + self.m2_y += delta_y * delta_y_2; + self.algo_const += delta_x * (value_y - self.mean_y); + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values_y = &cast(&values[0], &DataType::Float64)?; + let values_x = &cast(&values[1], &DataType::Float64)?; + + let mut arr_y = downcast_value!(values_y, Float64Array).iter().flatten(); + let mut arr_x = downcast_value!(values_x, Float64Array).iter().flatten(); + + for i in 0..values_y.len() { + // skip either x or y is NULL + let value_y = if values_y.is_valid(i) { + arr_y.next() + } else { + None + }; + let value_x = if values_x.is_valid(i) { + arr_x.next() + } else { + None + }; + if value_y.is_none() || value_x.is_none() { + continue; + } + + // Update states for regr_slope(y,x) [using cov_pop(x,y)/var_pop(x)] + let value_y = unwrap_or_internal_err!(value_y); + let value_x = unwrap_or_internal_err!(value_x); + + if self.count > 1 { + self.count -= 1; + let delta_x = value_x - self.mean_x; + let delta_y = value_y - self.mean_y; + self.mean_x -= delta_x / self.count as f64; + self.mean_y -= delta_y / self.count as f64; + let delta_x_2 = value_x - self.mean_x; + let delta_y_2 = value_y - self.mean_y; + self.m2_x -= delta_x * delta_x_2; + self.m2_y -= delta_y * delta_y_2; + self.algo_const -= delta_x * (value_y - self.mean_y); + } else { + self.count = 0; + self.mean_x = 0.0; + self.m2_x = 0.0; + self.m2_y = 0.0; + self.mean_y = 0.0; + self.algo_const = 0.0; + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let count_arr = downcast_value!(states[0], UInt64Array); + let mean_x_arr = downcast_value!(states[1], 
Float64Array); + let mean_y_arr = downcast_value!(states[2], Float64Array); + let m2_x_arr = downcast_value!(states[3], Float64Array); + let m2_y_arr = downcast_value!(states[4], Float64Array); + let algo_const_arr = downcast_value!(states[5], Float64Array); + + for i in 0..count_arr.len() { + let count_b = count_arr.value(i); + if count_b == 0_u64 { + continue; + } + let (count_a, mean_x_a, mean_y_a, m2_x_a, m2_y_a, algo_const_a) = ( + self.count, + self.mean_x, + self.mean_y, + self.m2_x, + self.m2_y, + self.algo_const, + ); + let (count_b, mean_x_b, mean_y_b, m2_x_b, m2_y_b, algo_const_b) = ( + count_b, + mean_x_arr.value(i), + mean_y_arr.value(i), + m2_x_arr.value(i), + m2_y_arr.value(i), + algo_const_arr.value(i), + ); + + // Assuming two different batches of input have calculated the states: + // batch A of Y, X -> {count_a, mean_x_a, mean_y_a, m2_x_a, algo_const_a} + // batch B of Y, X -> {count_b, mean_x_b, mean_y_b, m2_x_b, algo_const_b} + // The merged states from A and B are {count_ab, mean_x_ab, mean_y_ab, m2_x_ab, + // algo_const_ab} + // + // Reference for the algorithm to merge states: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + let count_ab = count_a + count_b; + let (count_a, count_b) = (count_a as f64, count_b as f64); + let d_x = mean_x_b - mean_x_a; + let d_y = mean_y_b - mean_y_a; + let mean_x_ab = mean_x_a + d_x * count_b / count_ab as f64; + let mean_y_ab = mean_y_a + d_y * count_b / count_ab as f64; + let m2_x_ab = + m2_x_a + m2_x_b + d_x * d_x * count_a * count_b / count_ab as f64; + let m2_y_ab = + m2_y_a + m2_y_b + d_y * d_y * count_a * count_b / count_ab as f64; + let algo_const_ab = algo_const_a + + algo_const_b + + d_x * d_y * count_a * count_b / count_ab as f64; + + self.count = count_ab; + self.mean_x = mean_x_ab; + self.mean_y = mean_y_ab; + self.m2_x = m2_x_ab; + self.m2_y = m2_y_ab; + self.algo_const = algo_const_ab; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + let cov_pop_x_y = self.algo_const / self.count as f64; + let var_pop_x = self.m2_x / self.count as f64; + let var_pop_y = self.m2_y / self.count as f64; + + let nullif_or_stat = |cond: bool, stat: f64| { + if cond { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(stat))) + } + }; + + match self.regr_type { + RegrType::Slope => { + // Only 0/1 point or slope is infinite + let nullif_cond = self.count <= 1 || var_pop_x == 0.0; + nullif_or_stat(nullif_cond, cov_pop_x_y / var_pop_x) + } + RegrType::Intercept => { + let slope = cov_pop_x_y / var_pop_x; + // Only 0/1 point or slope is infinite + let nullif_cond = self.count <= 1 || var_pop_x == 0.0; + nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x) + } + RegrType::Count => Ok(ScalarValue::Float64(Some(self.count as f64))), + RegrType::R2 => { + // Only 0/1 point or all x(or y) is the same + let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0; + nullif_or_stat( + nullif_cond, + (cov_pop_x_y * cov_pop_x_y) / (var_pop_x * var_pop_y), + ) + } + RegrType::AvgX => nullif_or_stat(self.count < 1, self.mean_x), + RegrType::AvgY => nullif_or_stat(self.count < 1, self.mean_y), + RegrType::SXX => nullif_or_stat(self.count < 1, self.m2_x), + RegrType::SYY => nullif_or_stat(self.count < 1, self.m2_y), + RegrType::SXY => nullif_or_stat(self.count < 1, self.algo_const), + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/physical-expr/src/array_expressions.rs 
b/datafusion/physical-expr/src/array_expressions.rs index 6940456657ad..a223a6998a39 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow_buffer::NullBuffer; use core::any::type_name; use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array}; -use datafusion_common::ScalarValue; +use datafusion_common::{plan_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use itertools::Itertools; @@ -255,9 +255,7 @@ fn compute_array_dims(arr: Option) -> Result>>> fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { // do not accept 0 arguments. if args.is_empty() { - return Err(DataFusionError::Plan( - "Array requires at least one argument".to_string(), - )); + return plan_err!("Array requires at least one argument"); } let res = match data_type { @@ -363,6 +361,177 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } +fn return_empty(return_null: bool, data_type: DataType) -> Arc { + if return_null { + new_null_array(&data_type, 1) + } else { + new_empty_array(&data_type) + } +} + +macro_rules! list_slice { + ($ARRAY:expr, $I:expr, $J:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ + let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap(); + if $I == 0 && $J == 0 || $ARRAY.is_empty() { + return return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()); + } + + let i = if $I < 0 { + if $I.abs() as usize > array.len() { + return return_empty(true, $ARRAY.data_type().clone()); + } + + (array.len() as i64 + $I + 1) as usize + } else { + if $I == 0 { + 1 + } else { + $I as usize + } + }; + let j = if $J < 0 { + if $J.abs() as usize > array.len() { + return return_empty(true, $ARRAY.data_type().clone()); + } + + if $RETURN_ELEMENT { + (array.len() as i64 + $J + 1) as usize + } else { + (array.len() as i64 + $J) as usize + } + } else { + if $J == 0 { + 1 + } else { + if $J as usize > array.len() { + array.len() + } else { + $J as usize + } + } + }; + + if i > j || i as usize > $ARRAY.len() { + return_empty($RETURN_ELEMENT, $ARRAY.data_type().clone()) + } else { + Arc::new(array.slice((i - 1), (j + 1 - i))) + } + }}; +} + +macro_rules! 
slice { + ($ARRAY:expr, $KEY:expr, $EXTRA_KEY:expr, $RETURN_ELEMENT:expr, $ARRAY_TYPE:ident) => {{ + let sliced_array: Vec> = $ARRAY + .iter() + .zip($KEY.iter()) + .zip($EXTRA_KEY.iter()) + .map(|((arr, i), j)| match (arr, i, j) { + (Some(arr), Some(i), Some(j)) => { + list_slice!(arr, i, j, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), None, Some(j)) => { + list_slice!(arr, 1i64, j, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), Some(i), None) => { + list_slice!(arr, i, arr.len() as i64, $RETURN_ELEMENT, $ARRAY_TYPE) + } + (Some(arr), None, None) if !$RETURN_ELEMENT => arr, + _ => return_empty($RETURN_ELEMENT, $ARRAY.value_type().clone()), + }) + .collect(); + + // concat requires input of at least one array + if sliced_array.is_empty() { + Ok(return_empty($RETURN_ELEMENT, $ARRAY.value_type())) + } else { + let vec = sliced_array + .iter() + .map(|a| a.as_ref()) + .collect::>(); + let mut i: i32 = 0; + let mut offsets = vec![i]; + offsets.extend( + vec.iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); + let values = compute::concat(vec.as_slice()).unwrap(); + + if $RETURN_ELEMENT { + Ok(values) + } else { + let field = + Arc::new(Field::new("item", $ARRAY.value_type().clone(), true)); + Ok(Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + values, + None, + )?)) + } + } + }}; +} + +fn define_array_slice( + list_array: &ListArray, + key: &Int64Array, + extra_key: &Int64Array, + return_element: bool, +) -> Result { + match list_array.value_type() { + DataType::List(_) => { + slice!(list_array, key, extra_key, return_element, ListArray) + } + DataType::Utf8 => slice!(list_array, key, extra_key, return_element, StringArray), + DataType::LargeUtf8 => { + slice!(list_array, key, extra_key, return_element, LargeStringArray) + } + DataType::Boolean => { + slice!(list_array, key, extra_key, return_element, BooleanArray) + } + DataType::Float32 => { + slice!(list_array, key, extra_key, return_element, Float32Array) + } + DataType::Float64 => { + slice!(list_array, key, extra_key, return_element, Float64Array) + } + DataType::Int8 => slice!(list_array, key, extra_key, return_element, Int8Array), + DataType::Int16 => slice!(list_array, key, extra_key, return_element, Int16Array), + DataType::Int32 => slice!(list_array, key, extra_key, return_element, Int32Array), + DataType::Int64 => slice!(list_array, key, extra_key, return_element, Int64Array), + DataType::UInt8 => slice!(list_array, key, extra_key, return_element, UInt8Array), + DataType::UInt16 => { + slice!(list_array, key, extra_key, return_element, UInt16Array) + } + DataType::UInt32 => { + slice!(list_array, key, extra_key, return_element, UInt32Array) + } + DataType::UInt64 => { + slice!(list_array, key, extra_key, return_element, UInt64Array) + } + data_type => Err(DataFusionError::NotImplemented(format!( + "array is not implemented for types '{data_type:?}'" + ))), + } +} + +pub fn array_element(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let key = as_int64_array(&args[1])?; + define_array_slice(list_array, key, key, true) +} + +pub fn array_slice(args: &[ArrayRef]) -> Result { + let list_array = as_list_array(&args[0])?; + let key = as_int64_array(&args[1])?; + let extra_key = as_int64_array(&args[2])?; + define_array_slice(list_array, key, extra_key, false) +} + macro_rules! 
append { ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ let mut offsets: Vec = vec![0]; @@ -657,84 +826,169 @@ pub fn array_concat(args: &[ArrayRef]) -> Result { concat_internal(new_args.as_slice()) } -macro_rules! fill { - ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); +macro_rules! general_repeat { + ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone(); - let mut acc = ColumnarValue::Scalar($ELEMENT); - for value in arr.iter().rev() { - match value { - Some(value) => { - let mut repeated = vec![]; - for _ in 0..value { - repeated.push(acc.clone()); - } - acc = array(repeated.as_slice()).unwrap(); + let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE); + for (el, c) in element_array.iter().zip($COUNT.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match el { + Some(el) => { + let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; + let repeated_array = + [Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>(); + + values = downcast_arg!( + compute::concat(&[&values, &repeated_array])?.clone(), + $ARRAY_TYPE + ) + .clone(); + offsets.push(last_offset + repeated_array.len() as i32); } - _ => { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires non nullable array" - ))); + None => { + offsets.push(last_offset); } } } - acc + let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + + Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?) }}; } -/// Array_fill SQL function -pub fn array_fill(args: &[ColumnarValue]) -> Result { - if args.len() != 2 { - return Err(DataFusionError::Internal(format!( - "Array_fill function requires two arguments, got {}", - args.len() - ))); - } +macro_rules! general_repeat_list { + ($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{ + let mut offsets: Vec = vec![0]; + let mut values = + downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone(); - let element = match &args[0] { - ColumnarValue::Scalar(scalar) => scalar.clone(), - _ => { - return Err(DataFusionError::Internal( - "Array_fill function requires scalar element".to_string(), - )) - } - }; + let element_array = downcast_arg!($ELEMENT, ListArray); + for (el, c) in element_array.iter().zip($COUNT.iter()) { + let last_offset: i32 = offsets.last().copied().ok_or_else(|| { + DataFusionError::Internal(format!("offsets should not be empty")) + })?; + match el { + Some(el) => { + let c = if c < Some(0) { 0 } else { c.unwrap() } as usize; + let repeated_vec = vec![el; c]; - let arr = match &args[1] { - ColumnarValue::Scalar(scalar) => scalar.to_array().clone(), - ColumnarValue::Array(arr) => arr.clone(), - }; + let mut i: i32 = 0; + let mut repeated_offsets = vec![i]; + repeated_offsets.extend( + repeated_vec + .clone() + .into_iter() + .map(|a| { + i += a.len() as i32; + i + }) + .collect::>(), + ); - let res = match arr.data_type() { - DataType::List(..) 
=> { - let arr = downcast_arg!(arr, ListArray); - let array_values = arr.values(); - match arr.value_type() { - DataType::Int8 => fill!(array_values, element, Int8Array), - DataType::Int16 => fill!(array_values, element, Int16Array), - DataType::Int32 => fill!(array_values, element, Int32Array), - DataType::Int64 => fill!(array_values, element, Int64Array), - DataType::UInt8 => fill!(array_values, element, UInt8Array), - DataType::UInt16 => fill!(array_values, element, UInt16Array), - DataType::UInt32 => fill!(array_values, element, UInt32Array), - DataType::UInt64 => fill!(array_values, element, UInt64Array), - DataType::Null => { - return Ok(datafusion_expr::ColumnarValue::Scalar( - ScalarValue::new_list(Some(vec![]), DataType::Null), - )) + let mut repeated_values = downcast_arg!( + new_empty_array(&element_array.value_type()), + $ARRAY_TYPE + ) + .clone(); + for repeated_list in repeated_vec { + repeated_values = downcast_arg!( + compute::concat(&[&repeated_values, &repeated_list])?, + $ARRAY_TYPE + ) + .clone(); + } + + let field = Arc::new(Field::new( + "item", + element_array.value_type().clone(), + true, + )); + let repeated_array = ListArray::try_new( + field, + OffsetBuffer::new(repeated_offsets.clone().into()), + Arc::new(repeated_values), + None, + )?; + + values = downcast_arg!( + compute::concat(&[&values, &repeated_array,])?.clone(), + ListArray + ) + .clone(); + offsets.push(last_offset + repeated_array.len() as i32); } - data_type => { - return Err(DataFusionError::Internal(format!( - "Array_fill is not implemented for type '{data_type:?}'." - ))); + None => { + offsets.push(last_offset); } } } + + let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true)); + + Arc::new(ListArray::try_new( + field, + OffsetBuffer::new(offsets.into()), + Arc::new(values), + None, + )?) + }}; +} + +/// Array_repeat SQL function +pub fn array_repeat(args: &[ArrayRef]) -> Result { + let element = &args[0]; + let count = as_int64_array(&args[1])?; + + let res = match element.data_type() { + DataType::List(field) => match field.data_type() { + DataType::List(_) => general_repeat_list!(element, count, ListArray), + DataType::Utf8 => general_repeat_list!(element, count, StringArray), + DataType::LargeUtf8 => general_repeat_list!(element, count, LargeStringArray), + DataType::Boolean => general_repeat_list!(element, count, BooleanArray), + DataType::Float32 => general_repeat_list!(element, count, Float32Array), + DataType::Float64 => general_repeat_list!(element, count, Float64Array), + DataType::Int8 => general_repeat_list!(element, count, Int8Array), + DataType::Int16 => general_repeat_list!(element, count, Int16Array), + DataType::Int32 => general_repeat_list!(element, count, Int32Array), + DataType::Int64 => general_repeat_list!(element, count, Int64Array), + DataType::UInt8 => general_repeat_list!(element, count, UInt8Array), + DataType::UInt16 => general_repeat_list!(element, count, UInt16Array), + DataType::UInt32 => general_repeat_list!(element, count, UInt32Array), + DataType::UInt64 => general_repeat_list!(element, count, UInt64Array), + data_type => { + return Err(DataFusionError::NotImplemented(format!( + "Array_repeat is not implemented for types 'List({data_type:?})'." 
+ ))) + } + }, + DataType::Utf8 => general_repeat!(element, count, StringArray), + DataType::LargeUtf8 => general_repeat!(element, count, LargeStringArray), + DataType::Boolean => general_repeat!(element, count, BooleanArray), + DataType::Float32 => general_repeat!(element, count, Float32Array), + DataType::Float64 => general_repeat!(element, count, Float64Array), + DataType::Int8 => general_repeat!(element, count, Int8Array), + DataType::Int16 => general_repeat!(element, count, Int16Array), + DataType::Int32 => general_repeat!(element, count, Int32Array), + DataType::Int64 => general_repeat!(element, count, Int64Array), + DataType::UInt8 => general_repeat!(element, count, UInt8Array), + DataType::UInt16 => general_repeat!(element, count, UInt16Array), + DataType::UInt32 => general_repeat!(element, count, UInt32Array), + DataType::UInt64 => general_repeat!(element, count, UInt64Array), data_type => { - return Err(DataFusionError::Internal(format!( - "Array is not type '{data_type:?}'." - ))); + return Err(DataFusionError::NotImplemented(format!( + "Array_repeat is not implemented for types '{data_type:?}'." + ))) } }; @@ -795,31 +1049,24 @@ pub fn array_position(args: &[ArrayRef]) -> Result { Int64Array::from_value(0, arr.len()) }; - let res = match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::List(_) => position!(arr, element, index, ListArray), - DataType::Utf8 => position!(arr, element, index, StringArray), - DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), - DataType::Boolean => position!(arr, element, index, BooleanArray), - DataType::Float32 => position!(arr, element, index, Float32Array), - DataType::Float64 => position!(arr, element, index, Float64Array), - DataType::Int8 => position!(arr, element, index, Int8Array), - DataType::Int16 => position!(arr, element, index, Int16Array), - DataType::Int32 => position!(arr, element, index, Int32Array), - DataType::Int64 => position!(arr, element, index, Int64Array), - DataType::UInt8 => position!(arr, element, index, UInt8Array), - DataType::UInt16 => position!(arr, element, index, UInt16Array), - DataType::UInt32 => position!(arr, element, index, UInt32Array), - DataType::UInt64 => position!(arr, element, index, UInt64Array), - data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array_position is not implemented for types '{data_type:?}'." - ))) - } - }, + let res = match arr.value_type() { + DataType::List(_) => position!(arr, element, index, ListArray), + DataType::Utf8 => position!(arr, element, index, StringArray), + DataType::LargeUtf8 => position!(arr, element, index, LargeStringArray), + DataType::Boolean => position!(arr, element, index, BooleanArray), + DataType::Float32 => position!(arr, element, index, Float32Array), + DataType::Float64 => position!(arr, element, index, Float64Array), + DataType::Int8 => position!(arr, element, index, Int8Array), + DataType::Int16 => position!(arr, element, index, Int16Array), + DataType::Int32 => position!(arr, element, index, Int32Array), + DataType::Int64 => position!(arr, element, index, Int64Array), + DataType::UInt8 => position!(arr, element, index, UInt8Array), + DataType::UInt16 => position!(arr, element, index, UInt16Array), + DataType::UInt32 => position!(arr, element, index, UInt32Array), + DataType::UInt64 => position!(arr, element, index, UInt64Array), data_type => { return Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." 
+ "Array_position is not implemented for types '{data_type:?}'." ))) } }; @@ -881,31 +1128,24 @@ pub fn array_positions(args: &[ArrayRef]) -> Result { let arr = as_list_array(&args[0])?; let element = &args[1]; - let res = match arr.data_type() { - DataType::List(field) => match field.data_type() { - DataType::List(_) => positions!(arr, element, ListArray), - DataType::Utf8 => positions!(arr, element, StringArray), - DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), - DataType::Boolean => positions!(arr, element, BooleanArray), - DataType::Float32 => positions!(arr, element, Float32Array), - DataType::Float64 => positions!(arr, element, Float64Array), - DataType::Int8 => positions!(arr, element, Int8Array), - DataType::Int16 => positions!(arr, element, Int16Array), - DataType::Int32 => positions!(arr, element, Int32Array), - DataType::Int64 => positions!(arr, element, Int64Array), - DataType::UInt8 => positions!(arr, element, UInt8Array), - DataType::UInt16 => positions!(arr, element, UInt16Array), - DataType::UInt32 => positions!(arr, element, UInt32Array), - DataType::UInt64 => positions!(arr, element, UInt64Array), - data_type => { - return Err(DataFusionError::NotImplemented(format!( - "Array_positions is not implemented for types '{data_type:?}'." - ))) - } - }, + let res = match arr.value_type() { + DataType::List(_) => positions!(arr, element, ListArray), + DataType::Utf8 => positions!(arr, element, StringArray), + DataType::LargeUtf8 => positions!(arr, element, LargeStringArray), + DataType::Boolean => positions!(arr, element, BooleanArray), + DataType::Float32 => positions!(arr, element, Float32Array), + DataType::Float64 => positions!(arr, element, Float64Array), + DataType::Int8 => positions!(arr, element, Int8Array), + DataType::Int16 => positions!(arr, element, Int16Array), + DataType::Int32 => positions!(arr, element, Int32Array), + DataType::Int64 => positions!(arr, element, Int64Array), + DataType::UInt8 => positions!(arr, element, UInt8Array), + DataType::UInt16 => positions!(arr, element, UInt16Array), + DataType::UInt32 => positions!(arr, element, UInt32Array), + DataType::UInt64 => positions!(arr, element, UInt64Array), data_type => { return Err(DataFusionError::NotImplemented(format!( - "Array is not type '{data_type:?}'." + "Array_positions is not implemented for types '{data_type:?}'." 
))) } }; @@ -1483,25 +1723,6 @@ pub fn array_to_string(args: &[ArrayRef]) -> Result { Ok(Arc::new(StringArray::from(res))) } -/// Trim_array SQL function -pub fn trim_array(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let n = as_int64_array(&args[1])?.value(0) as usize; - - let values = list_array.value(0); - if values.len() <= n { - return Ok(array(&[ColumnarValue::Scalar(ScalarValue::Null)])?.into_array(1)); - } - - let res = values.slice(0, values.len() - n); - let mut scalars = vec![]; - for i in 0..res.len() { - scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&res, i)?)); - } - - Ok(array(scalars.as_slice())?.into_array(1)) -} - /// Cardinality SQL function pub fn cardinality(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?.clone(); @@ -1952,6 +2173,364 @@ mod tests { ) } + #[test] + fn test_array_element() { + // array_element([1, 2, 3, 4], 1) = 1 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(1, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(1, 1)); + + // array_element([1, 2, 3, 4], 3) = 3 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(3, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(3, 1)); + + // array_element([1, 2, 3, 4], 0) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(0, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from(vec![None])); + + // array_element([1, 2, 3, 4], NULL) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from(vec![None]))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from(vec![None])); + + // array_element([1, 2, 3, 4], -1) = 4 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-1, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(4, 1)); + + // array_element([1, 2, 3, 4], -3) = 2 + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(-3, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, &Int64Array::from_value(2, 1)); + + // array_element([1, 2, 3, 4], 10) = NULL + let list_array = return_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(10, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_int64_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!(result, 
&Int64Array::from(vec![None])); + } + + #[test] + fn test_nested_array_element() { + // array_element([[1, 2, 3, 4], [5, 6, 7, 8]], 2) = [5, 6, 7, 8] + let list_array = return_nested_array().into_array(1); + let arr = array_element(&[list_array, Arc::new(Int64Array::from_value(2, 1))]) + .expect("failed to initialize function array_element"); + let result = + as_list_array(&arr).expect("failed to initialize function array_element"); + + assert_eq!( + &[5, 6, 7, 8], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + + #[test] + fn test_array_slice() { + // array_slice([1, 2, 3, 4], 1, 3) = [1, 2, 3] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(1, 1)), + Arc::new(Int64Array::from_value(3, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[1, 2, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 2, 2) = [2] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(2, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 0, 0) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(0, 1)), + Arc::new(Int64Array::from_value(0, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], 0, 6) = [1, 2, 3, 4] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(0, 1)), + Arc::new(Int64Array::from_value(6, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], -2, -2) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-2, 1)), + Arc::new(Int64Array::from_value(-2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], -3, -1) = [2, 3] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-3, 1)), + Arc::new(Int64Array::from_value(-1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], -3, 2) = [2] + let list_array = return_array().into_array(1); 
+ let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-3, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 2, 11) = [2, 3, 4] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(2, 1)), + Arc::new(Int64Array::from_value(11, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([1, 2, 3, 4], 3, 1) = [] + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(3, 1)), + Arc::new(Int64Array::from_value(1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([1, 2, 3, 4], -7, -2) = NULL + let list_array = return_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-7, 1)), + Arc::new(Int64Array::from_value(-2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_null(0)); + } + + #[test] + fn test_nested_array_slice() { + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], 1, 1) = [[1, 2, 3, 4]] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(1, 1)), + Arc::new(Int64Array::from_value(1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[1, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, -1) = [] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-1, 1)), + Arc::new(Int64Array::from_value(-1, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert!(result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .is_empty()); + + // array_slice([[1, 2, 3, 4], [5, 6, 7, 8]], -1, 2) = [[5, 6, 7, 8]] + let list_array = return_nested_array().into_array(1); + let arr = array_slice(&[ + list_array, + Arc::new(Int64Array::from_value(-1, 1)), + Arc::new(Int64Array::from_value(2, 1)), + ]) + .expect("failed to initialize function array_slice"); + let result = + as_list_array(&arr).expect("failed to initialize function array_slice"); + + assert_eq!( + &[5, 6, 7, 8], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + #[test] fn test_array_append() { // array_append([1, 2, 3], 4) 
= [1, 2, 3, 4] @@ -2077,35 +2656,6 @@ mod tests { ); } - #[test] - fn test_array_fill() { - // array_fill(4, [5]) = [4, 4, 4, 4, 4] - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ColumnarValue::Scalar(ScalarValue::List( - Some(vec![ScalarValue::Int64(Some(5))]), - Arc::new(Field::new("item", DataType::Int64, false)), - )), - ]; - - let array = array_fill(&args) - .expect("failed to initialize function array_fill") - .into_array(1); - let result = - as_list_array(&array).expect("failed to initialize function array_fill"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[4, 4, 4, 4, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - #[test] fn test_array_position() { // array_position([1, 2, 3, 4], 3) = 3 @@ -2494,6 +3044,55 @@ mod tests { ); } + #[test] + fn test_array_repeat() { + // array_repeat(3, 5) = [3, 3, 3, 3, 3] + let array = array_repeat(&[ + Arc::new(Int64Array::from_value(3, 1)), + Arc::new(Int64Array::from_value(5, 1)), + ]) + .expect("failed to initialize function array_repeat"); + let result = + as_list_array(&array).expect("failed to initialize function array_repeat"); + + assert_eq!(result.len(), 1); + assert_eq!( + &[3, 3, 3, 3, 3], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + + #[test] + fn test_nested_array_repeat() { + // array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] + let element = return_array().into_array(1); + let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))]) + .expect("failed to initialize function array_repeat"); + let result = + as_list_array(&array).expect("failed to initialize function array_repeat"); + + assert_eq!(result.len(), 1); + let data = vec![ + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + Some(vec![Some(1), Some(2), Some(3), Some(4)]), + ]; + let expected = ListArray::from_iter_primitive::(data); + assert_eq!( + expected, + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .clone() + ); + } #[test] fn test_array_to_string() { // array_to_string([1, 2, 3, 4], ',') = 1,2,3,4 @@ -2550,69 +3149,6 @@ mod tests { assert_eq!("1-*-3-*-*-6-7-*", result.value(0)); } - #[test] - fn test_trim_array() { - // trim_array([1, 2, 3, 4], 1) = [1, 2, 3] - let list_array = return_array().into_array(1); - let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) - .expect("failed to initialize function trim_array"); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // trim_array([1, 2, 3, 4], 3) = [1] - let list_array = return_array().into_array(1); - let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(3)]))]) - .expect("failed to initialize function trim_array"); - let result = - as_list_array(&arr).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_nested_trim_array() { - // trim_array([[1, 2, 3, 4], [5, 6, 7, 8]], 1) = [[1, 2, 3, 4]] - let list_array = return_nested_array().into_array(1); - let arr = trim_array(&[list_array, Arc::new(Int64Array::from(vec![Some(1)]))]) - .expect("failed to initialize function trim_array"); - let binding = 
as_list_array(&arr) - .expect("failed to initialize function trim_array") - .value(0); - let result = - as_list_array(&binding).expect("failed to initialize function trim_array"); - - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - #[test] fn test_cardinality() { // cardinality([1, 2, 3, 4]) = 4 diff --git a/datafusion/physical-expr/src/crypto_expressions.rs b/datafusion/physical-expr/src/crypto_expressions.rs index c940933b102f..3c2a095361c6 100644 --- a/datafusion/physical-expr/src/crypto_expressions.rs +++ b/datafusion/physical-expr/src/crypto_expressions.rs @@ -23,10 +23,11 @@ use arrow::{ }; use blake2::{Blake2b512, Blake2s256, Digest}; use blake3::Hasher as Blake3; -use datafusion_common::cast::{ - as_binary_array, as_generic_binary_array, as_generic_string_array, -}; use datafusion_common::ScalarValue; +use datafusion_common::{ + cast::{as_binary_array, as_generic_binary_array, as_generic_string_array}, + plan_err, +}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use md5::Md5; @@ -224,9 +225,9 @@ impl FromStr for DigestAlgorithm { .map(|i| i.to_string()) .collect::>() .join(", "); - return Err(DataFusionError::Plan(format!( - "There is no built-in digest algorithm named '{name}', currently supported algorithms are: {options}", - ))); + return plan_err!( + "There is no built-in digest algorithm named '{name}', currently supported algorithms are: {options}" + ); } }) } diff --git a/datafusion/physical-expr/src/encoding_expressions.rs b/datafusion/physical-expr/src/encoding_expressions.rs index e8b4331e9298..88d1bec70fda 100644 --- a/datafusion/physical-expr/src/encoding_expressions.rs +++ b/datafusion/physical-expr/src/encoding_expressions.rs @@ -22,8 +22,11 @@ use arrow::{ datatypes::DataType, }; use base64::{engine::general_purpose, Engine as _}; -use datafusion_common::cast::{as_generic_binary_array, as_generic_string_array}; use datafusion_common::ScalarValue; +use datafusion_common::{ + cast::{as_generic_binary_array, as_generic_string_array}, + plan_err, +}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -279,9 +282,9 @@ impl FromStr for Encoding { .map(|i| i.to_string()) .collect::>() .join(", "); - return Err(DataFusionError::Plan(format!( - "There is no built-in encoding named '{name}', currently supported encodings are: {options}", - ))); + return plan_err!( + "There is no built-in encoding named '{name}', currently supported encodings are: {options}" + ); } }) } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index d2fc4600b906..34633f6e1dc3 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -86,6 +86,7 @@ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use datafusion_common::cast::as_boolean_array; +use datafusion_common::plan_err; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::type_coercion::binary::{ @@ -1124,9 +1125,9 @@ pub fn binary( if (is_utf8_or_large_utf8(lhs_type) && is_timestamp(rhs_type)) || (is_timestamp(lhs_type) && is_utf8_or_large_utf8(rhs_type)) { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The type of {lhs_type} {op:?} {rhs_type} of binary physical should be same" - ))); + ); } if !lhs_type.eq(rhs_type) && 
(!is_decimal(lhs_type) && !is_decimal(rhs_type)) { return Err(DataFusionError::Internal(format!( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 91fa9bbb9309..506c01b6f371 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -404,6 +404,7 @@ mod tests { use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; + use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; @@ -966,9 +967,9 @@ mod tests { let coerce_type = get_case_common_type(&when_thens, else_expr.clone(), input_schema); let (when_thens, else_expr) = match coerce_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression" - ))), + ), Some(data_type) => { // cast then expr let left = when_thens diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 3b0d77b30431..f40a05c49ae6 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -28,6 +28,7 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; +use datafusion_common::plan_err; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::ColumnarValue; @@ -175,9 +176,7 @@ impl PhysicalExpr for UnKnownColumn { /// Evaluate the expression fn evaluate(&self, _batch: &RecordBatch) -> Result { - Err(DataFusionError::Plan( - "UnKnownColumn::evaluate() should not be called".to_owned(), - )) + plan_err!("UnKnownColumn::evaluate() should not be called") } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 090cfe5a6e64..596b8f414f06 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -19,52 +19,124 @@ use crate::PhysicalExpr; use arrow::array::Array; -use arrow::compute::concat; +use crate::array_expressions::{array_element, array_slice}; use crate::physical_expr::down_cast_any_ref; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::cast::{as_list_array, as_struct_array}; -use datafusion_common::DataFusionError; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::{ - field_util::get_indexed_field as get_data_type_field, ColumnarValue, -}; -use std::convert::TryInto; +use datafusion_common::{cast::as_struct_array, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; use std::fmt::Debug; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -/// expression to get a field of a struct array. 
+/// Access a sub field of a nested type, such as `Field` or `List` +#[derive(Clone, Hash, Debug)] +pub enum GetFieldAccessExpr { + /// Named field, For example `struct["name"]` + NamedStructField { name: ScalarValue }, + /// Single list index, for example: `list[i]` + ListIndex { key: Arc }, + /// List range, for example `list[i:j]` + ListRange { + start: Arc, + stop: Arc, + }, +} + +impl std::fmt::Display for GetFieldAccessExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), + GetFieldAccessExpr::ListIndex { key } => write!(f, "[{}]", key), + GetFieldAccessExpr::ListRange { start, stop } => { + write!(f, "[{}:{}]", start, stop) + } + } + } +} + +impl PartialEq for GetFieldAccessExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.eq(x)) + .unwrap_or(false) + } +} + +/// Expression to get a field of a struct array. #[derive(Debug, Hash)] pub struct GetIndexedFieldExpr { + /// The expression to find arg: Arc, - key: ScalarValue, + /// The key statement + field: GetFieldAccessExpr, } impl GetIndexedFieldExpr { - /// Create new get field expression - pub fn new(arg: Arc, key: ScalarValue) -> Self { - Self { arg, key } + /// Create new [`GetIndexedFieldExpr`] + pub fn new(arg: Arc, field: GetFieldAccessExpr) -> Self { + Self { arg, field } + } + + /// Create a new [`GetIndexedFieldExpr`] for accessing the named field + pub fn new_field(arg: Arc, name: impl Into) -> Self { + Self::new( + arg, + GetFieldAccessExpr::NamedStructField { + name: ScalarValue::Utf8(Some(name.into())), + }, + ) + } + + /// Create a new [`GetIndexedFieldExpr`] for accessing the specified index + pub fn new_index(arg: Arc, key: Arc) -> Self { + Self::new(arg, GetFieldAccessExpr::ListIndex { key }) + } + + /// Create a new [`GetIndexedFieldExpr`] for accessing the range + pub fn new_range( + arg: Arc, + start: Arc, + stop: Arc, + ) -> Self { + Self::new(arg, GetFieldAccessExpr::ListRange { start, stop }) } - /// Get the input key - pub fn key(&self) -> &ScalarValue { - &self.key + /// Get the description of what field should be accessed + pub fn field(&self) -> &GetFieldAccessExpr { + &self.field } /// Get the input expression pub fn arg(&self) -> &Arc { &self.arg } + + fn schema_access(&self, input_schema: &Schema) -> Result { + Ok(match &self.field { + GetFieldAccessExpr::NamedStructField { name } => { + GetFieldAccessSchema::NamedStructField { name: name.clone() } + } + GetFieldAccessExpr::ListIndex { key } => GetFieldAccessSchema::ListIndex { + key_dt: key.data_type(input_schema)?, + }, + GetFieldAccessExpr::ListRange { start, stop } => { + GetFieldAccessSchema::ListRange { + start_dt: start.data_type(input_schema)?, + stop_dt: stop.data_type(input_schema)?, + } + } + }) + } } impl std::fmt::Display for GetIndexedFieldExpr { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "({}).[{}]", self.arg, self.key) + write!(f, "({}).{}", self.arg, self.field) } } @@ -74,70 +146,67 @@ impl PhysicalExpr for GetIndexedFieldExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.data_type().clone()) + let arg_dt = self.arg.data_type(input_schema)?; + self.schema_access(input_schema)? 
+ .get_accessed_field(&arg_dt) + .map(|f| f.data_type().clone()) } fn nullable(&self, input_schema: &Schema) -> Result { - let data_type = self.arg.data_type(input_schema)?; - get_data_type_field(&data_type, &self.key).map(|f| f.is_nullable()) + let arg_dt = self.arg.data_type(input_schema)?; + self.schema_access(input_schema)? + .get_accessed_field(&arg_dt) + .map(|f| f.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(1); - match (array.data_type(), &self.key) { - (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => { - let scalar_null: ScalarValue = array.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) - } - (DataType::List(lst), ScalarValue::Int64(Some(i))) => { - let as_list_array = as_list_array(&array)?; - - if *i < 1 || as_list_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - return Ok(ColumnarValue::Scalar(scalar_null)) + let array = self.arg.evaluate(batch)?.into_array(batch.num_rows()); + match &self.field { + GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(k) { + None => Err(DataFusionError::Execution( + format!("get indexed field {k} not found in struct"))), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } } - - let sliced_array: Vec> = as_list_array - .iter() - .filter_map(|o| match o { - Some(list) => if *i as usize > list.len() { - None - } else { - Some(list.slice((*i -1) as usize, 1)) - }, - None => None - }) - .collect(); - - // concat requires input of at least one array - if sliced_array.is_empty() { - let scalar_null: ScalarValue = lst.data_type().try_into()?; - Ok(ColumnarValue::Scalar(scalar_null)) - } else { - let vec = sliced_array.iter().map(|a| a.as_ref()).collect::>(); - let iter = concat(vec.as_slice()).unwrap(); - - Ok(ColumnarValue::Array(iter)) - } - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => Err(DataFusionError::Execution( - format!("get indexed field {k} not found in struct"))), - Some(col) => Ok(ColumnarValue::Array(col.clone())) + (DataType::Struct(_), name) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on struct with utf8 indexes. \ + Tried with {name:?} index"))), + (dt, name) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {name:?} index"))), + }, + GetFieldAccessExpr::ListIndex{key} => { + let key = key.evaluate(batch)?.into_array(batch.num_rows()); + match (array.data_type(), key.data_type()) { + (DataType::List(_), DataType::Int64) => Ok(ColumnarValue::Array(array_element(&[ + array, key + ])?)), + (DataType::List(_), key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes. \ + Tried with {key:?} index"))), + (dt, key) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. 
Tried {dt:?} with {key:?} index"))), + } + }, + GetFieldAccessExpr::ListRange{start, stop} => { + let start = start.evaluate(batch)?.into_array(batch.num_rows()); + let stop = stop.evaluate(batch)?.into_array(batch.num_rows()); + match (array.data_type(), start.data_type(), stop.data_type()) { + (DataType::List(_), DataType::Int64, DataType::Int64) => Ok(ColumnarValue::Array(array_slice(&[ + array, start, stop + ])?)), + (DataType::List(_), start, stop) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes. \ + Tried with {start:?} and {stop:?} indices"))), + (dt, start, stop) => Err(DataFusionError::Execution( + format!("get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {start:?} and {stop:?} indices"))), } - } - (DataType::List(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes. \ - Tried with {key:?} index"))), - (DataType::Struct(_), key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on struct with utf8 indexes. \ - Tried with {key:?} index"))), - (dt, key) => Err(DataFusionError::Execution( - format!("get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {key:?} index"))), + }, } } @@ -151,7 +220,7 @@ impl PhysicalExpr for GetIndexedFieldExpr { ) -> Result> { Ok(Arc::new(GetIndexedFieldExpr::new( children[0].clone(), - self.key.clone(), + self.field.clone(), ))) } @@ -165,7 +234,7 @@ impl PartialEq for GetIndexedFieldExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.arg.eq(&x.arg) && self.key == x.key) + .map(|x| self.arg.eq(&x.arg) && self.field.eq(&x.field)) .unwrap_or(false) } } @@ -173,301 +242,196 @@ impl PartialEq for GetIndexedFieldExpr { #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; - use arrow::array::{ArrayRef, Float64Array, GenericListArray, PrimitiveBuilder}; + use crate::expressions::col; + use arrow::array::new_empty_array; + use arrow::array::{ArrayRef, GenericListArray}; use arrow::array::{ - Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder, + BooleanArray, Int64Array, ListBuilder, StringBuilder, StructArray, }; - use arrow::datatypes::{Float64Type, Int64Type}; + use arrow::datatypes::Fields; use arrow::{array::StringArray, datatypes::Field}; - use datafusion_common::cast::{as_int64_array, as_string_array}; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_string_array}; use datafusion_common::Result; - fn build_utf8_lists(list_of_lists: Vec>>) -> GenericListArray { + fn build_list_arguments( + list_of_lists: Vec>>, + list_of_start_indices: Vec>, + list_of_stop_indices: Vec>, + ) -> (GenericListArray, Int64Array, Int64Array) { let builder = StringBuilder::with_capacity(list_of_lists.len(), 1024); - let mut lb = ListBuilder::new(builder); + let mut list_builder = ListBuilder::new(builder); for values in list_of_lists { - let builder = lb.values(); + let builder = list_builder.values(); for value in values { match value { None => builder.append_null(), Some(v) => builder.append_value(v), } } - lb.append(true); + list_builder.append(true); } - lb.finish() + let start_array = Int64Array::from(list_of_start_indices); + let stop_array = Int64Array::from(list_of_stop_indices); + (list_builder.finish(), start_array, stop_array) } - fn get_indexed_field_test( - 
list_of_lists: Vec>>, - index: i64, - expected: Vec>, - ) -> Result<()> { - let schema = list_schema("l"); - let list_col = build_utf8_lists(list_of_lists); - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?; - let key = ScalarValue::Int64(Some(index)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_string_array(&result).expect("failed to downcast to StringArray"); - let expected = &StringArray::from(expected); - assert_eq!(expected, result); + #[test] + fn get_indexed_field_named_struct_field() -> Result<()> { + let schema = struct_schema(); + let boolean = BooleanArray::from(vec![false, false, true, true]); + let int = Int64Array::from(vec![42, 28, 19, 31]); + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, true)), + Arc::new(boolean.clone()) as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int64, true)), + Arc::new(int) as ArrayRef, + ), + ]); + let expr = col("str", &schema).unwrap(); + // only one row should be processed + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; + let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); + let result = expr.evaluate(&batch)?.into_array(1); + let result = + as_boolean_array(&result).expect("failed to downcast to BooleanArray"); + assert_eq!(boolean, result.clone()); Ok(()) } - fn list_schema(col: &str) -> Schema { - Schema::new(vec![Field::new_list( - col, - Field::new("item", DataType::Utf8, true), + fn struct_schema() -> Schema { + Schema::new(vec![Field::new_struct( + "str", + Fields::from(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int64, true), + ]), true, )]) } + fn list_schema(cols: &[&str]) -> Schema { + if cols.len() == 2 { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + ]) + } else { + Schema::new(vec![ + Field::new_list(cols[0], Field::new("item", DataType::Utf8, true), true), + Field::new(cols[1], DataType::Int64, true), + Field::new(cols[2], DataType::Int64, true), + ]) + } + } + #[test] - fn get_indexed_field_list() -> Result<()> { + fn get_indexed_field_list_index() -> Result<()> { let list_of_lists = vec![ vec![Some("a"), Some("b"), None], vec![None, Some("c"), Some("d")], vec![Some("e"), None, Some("f")], ]; - let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], - ]; - - for (i, expected) in expected_list.into_iter().enumerate() { - get_indexed_field_test(list_of_lists.clone(), (i + 1) as i64, expected)?; - } - Ok(()) - } - - #[test] - fn get_indexed_field_empty_list() -> Result<()> { - let schema = list_schema("l"); - let builder = StringBuilder::new(); - let mut lb = ListBuilder::new(builder); - let expr = col("l", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let key = ScalarValue::Int64(Some(1)); - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); - assert!(result.is_empty()); - Ok(()) - } - - fn get_indexed_field_test_failure( - schema: Schema, - expr: Arc, - key: ScalarValue, - expected: &str, - ) -> Result<()> { - let builder = StringBuilder::with_capacity(3, 1024); - let mut lb = ListBuilder::new(builder); - let batch = 
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?; - let expr = Arc::new(GetIndexedFieldExpr::new(expr, key)); - let r = expr.evaluate(&batch).map(|_| ()); - assert!(r.is_err()); - assert_eq!(format!("{}", r.unwrap_err()), expected); - Ok(()) - } - - #[test] - fn get_indexed_field_invalid_scalar() -> Result<()> { - let schema = list_schema("l"); - let expr = lit("a"); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int64(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes or \ - struct with utf8 indexes. Tried Utf8 with Int64(0) index") - } - - #[test] - fn get_indexed_field_invalid_list_index() -> Result<()> { - let schema = list_schema("l"); - let expr = col("l", &schema).unwrap(); - get_indexed_field_test_failure( - schema, expr, ScalarValue::Int8(Some(0)), - "Execution error: get indexed field is only possible on lists with int64 indexes. \ - Tried with Int8(0) index") - } - - fn build_struct( - fields: Vec, - list_of_tuples: Vec<(Option, Vec>)>, - ) -> StructArray { - let foo_builder = Int64Array::builder(list_of_tuples.len()); - let str_builder = StringBuilder::with_capacity(list_of_tuples.len(), 1024); - let bar_builder = ListBuilder::new(str_builder); - let mut builder = StructBuilder::new( - fields, - vec![Box::new(foo_builder), Box::new(bar_builder)], + let list_of_start_indices = vec![Some(1), Some(2), None]; + let list_of_stop_indices = vec![None]; + let expected_list = vec![Some("a"), Some("c"), None]; + + let schema = list_schema(&["list", "key"]); + let (list_col, key_col, _) = build_list_arguments( + list_of_lists, + list_of_start_indices, + list_of_stop_indices, ); - for (int_value, list_value) in list_of_tuples { - let fb = builder.field_builder::(0).unwrap(); - match int_value { - None => fb.append_null(), - Some(v) => fb.append_value(v), - }; - builder.append(true); - let lb = builder - .field_builder::>(1) - .unwrap(); - for str_value in list_value { - match str_value { - None => lb.values().append_null(), - Some(v) => lb.values().append_value(v), - }; - } - lb.append(true); - } - builder.finish() - } - - fn get_indexed_field_mixed_test( - list_of_tuples: Vec<(Option, Vec>)>, - expected_strings: Vec>>, - expected_ints: Vec>, - ) -> Result<()> { - let struct_col = "s"; - let fields = vec![ - Field::new("foo", DataType::Int64, true), - Field::new_list("bar", Field::new("item", DataType::Utf8, true), true), - ]; - let schema = Schema::new(vec![Field::new( - struct_col, - DataType::Struct(fields.clone().into()), - true, - )]); - let struct_col = build_struct(fields, list_of_tuples.clone()); - - let struct_col_expr = col("s", &schema).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_col)])?; - - let int_field_key = ScalarValue::Utf8(Some("foo".to_string())); - let get_field_expr = Arc::new(GetIndexedFieldExpr::new( - struct_col_expr.clone(), - int_field_key, - )); - let result = get_field_expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()); - let result = as_int64_array(&result)?; - let expected = &Int64Array::from(expected_ints); - assert_eq!(expected, result); - - let list_field_key = ScalarValue::Utf8(Some("bar".to_string())); - let get_list_expr = - Arc::new(GetIndexedFieldExpr::new(struct_col_expr, list_field_key)); - let result = get_list_expr.evaluate(&batch)?.into_array(batch.num_rows()); - let result = as_list_array(&result)?; - let expected = - &build_utf8_lists(list_of_tuples.into_iter().map(|t| t.1).collect()); - assert_eq!(expected, result); - - for (i, expected) in expected_strings.into_iter().enumerate() { - let get_nested_str_expr = Arc::new(GetIndexedFieldExpr::new( - get_list_expr.clone(), - ScalarValue::Int64(Some((i + 1) as i64)), - )); - let result = get_nested_str_expr - .evaluate(&batch)? - .into_array(batch.num_rows()); - let result = as_string_array(&result)?; - let expected = &StringArray::from(expected); - assert_eq!(expected, result); - } + let expr = col("list", &schema).unwrap(); + let key = col("key", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_col), Arc::new(key_col)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr.evaluate(&batch)?.into_array(1); + let result = as_string_array(&result).expect("failed to downcast to ListArray"); + let expected = StringArray::from(expected_list); + assert_eq!(expected, result.clone()); Ok(()) } #[test] - fn get_indexed_field_struct() -> Result<()> { - let list_of_structs = vec![ - (Some(10), vec![Some("a"), Some("b"), None]), - (Some(15), vec![None, Some("c"), Some("d")]), - (None, vec![Some("e"), None, Some("f")]), + fn get_indexed_field_list_range() -> Result<()> { + let list_of_lists = vec![ + vec![Some("a"), Some("b"), None], + vec![None, Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], ]; - + let list_of_start_indices = vec![Some(1), Some(2), None]; + let list_of_stop_indices = vec![Some(2), None, Some(3)]; let expected_list = vec![ - vec![Some("a"), None, Some("e")], - vec![Some("b"), Some("c"), None], - vec![None, Some("d"), Some("f")], + vec![Some("a"), Some("b")], + vec![Some("c"), Some("d")], + vec![Some("e"), None, Some("f")], ]; - let expected_ints = vec![Some(10), Some(15), None]; - - get_indexed_field_mixed_test( - list_of_structs.clone(), - expected_list, - expected_ints, + let schema = list_schema(&["list", "start", "stop"]); + let (list_col, start_col, stop_col) = build_list_arguments( + list_of_lists, + list_of_start_indices, + list_of_stop_indices, + ); + let expr = col("list", &schema).unwrap(); + let start = col("start", &schema).unwrap(); + let stop = col("stop", &schema).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)], )?; + let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop)); + let result = expr.evaluate(&batch)?.into_array(1); + let result = as_list_array(&result).expect("failed to downcast to ListArray"); + let (expected, _, _) = + build_list_arguments(expected_list, vec![None], vec![None]); + assert_eq!(expected, result.clone()); Ok(()) } #[test] - fn get_indexed_field_list_out_of_bounds() { - let fields = vec![ - Field::new("id", DataType::Int64, true), - Field::new_list("a", Field::new("item", DataType::Float64, true), true), - ]; - - let schema = Schema::new(fields); - let mut int_builder = PrimitiveBuilder::::new(); - int_builder.append_value(1); - - let mut lb = 
ListBuilder::new(PrimitiveBuilder::::new()); - lb.values().append_value(1.0); - lb.values().append_null(); - lb.values().append_value(3.0); - lb.append(true); - + fn get_indexed_field_empty_list() -> Result<()> { + let schema = list_schema(&["list", "key"]); + let builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(builder); + let key_array = new_empty_array(&DataType::Int64); + let expr = col("list", &schema).unwrap(); + let key = col("key", &schema).unwrap(); let batch = RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(int_builder.finish()), Arc::new(lb.finish())], - ) - .unwrap(); - - let col_a = col("a", &schema).unwrap(); - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 0, float64_array(None)); - - verify_index_evaluation(&batch, col_a.clone(), 1, float64_array(Some(1.0))); - verify_index_evaluation(&batch, col_a.clone(), 2, float64_array(None)); - verify_index_evaluation(&batch, col_a.clone(), 3, float64_array(Some(3.0))); - - // out of bounds index - verify_index_evaluation(&batch, col_a.clone(), 100, float64_array(None)); + Arc::new(schema), + vec![Arc::new(list_builder.finish()), key_array], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr.evaluate(&batch)?.into_array(batch.num_rows()); + assert!(result.is_null(0)); + Ok(()) } - fn verify_index_evaluation( - batch: &RecordBatch, - arg: Arc, - index: i64, - expected_result: ArrayRef, - ) { - let expr = Arc::new(GetIndexedFieldExpr::new( - arg, - ScalarValue::Int64(Some(index)), - )); - let result = expr.evaluate(batch).unwrap().into_array(batch.num_rows()); - assert!( - result == expected_result.clone(), - "result: {result:?} != expected result: {expected_result:?}" - ); - assert_eq!(result.data_type(), &DataType::Float64); - } + #[test] + fn get_indexed_field_invalid_list_index() -> Result<()> { + let schema = list_schema(&["list", "error"]); + let expr = col("list", &schema).unwrap(); + let key = col("error", &schema).unwrap(); + let builder = StringBuilder::with_capacity(3, 1024); + let mut list_builder = ListBuilder::new(builder); + list_builder.values().append_value("hello"); + list_builder.append(true); - fn float64_array(value: Option) -> ArrayRef { - match value { - Some(v) => Arc::new(Float64Array::from_value(v, 1)), - None => { - let mut b = PrimitiveBuilder::::new(); - b.append_null(); - Arc::new(b.finish()) - } - } + let key_array = Int64Array::from(vec![Some(3)]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(list_builder.finish()), Arc::new(key_array)], + )?; + let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key)); + let result = expr.evaluate(&batch)?.into_array(1); + assert!(result.is_null(0)); + Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3eb25cea82cf..722edb22dea6 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -377,6 +377,7 @@ mod tests { use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; + use datafusion_common::plan_err; use datafusion_common::Result; use datafusion_expr::type_coercion::binary::comparison_coercion; @@ -396,9 +397,9 @@ mod tests { .collect(); let result_type = get_coerce_type(expr_type, &list_types); match result_type { - None => Err(DataFusionError::Plan(format!( + None => plan_err!( "Can not find compatible types to compare {expr_type:?} with 
{list_types:?}" - ))), + ), Some(data_type) => { // find the coerced type let cast_expr = try_cast(expr, input_schema, data_type.clone())?; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c660cfadcca1..b7e9d2cd8010 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -60,6 +60,7 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; +pub use crate::aggregate::regr::Regr; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; pub use crate::aggregate::sum::Sum; @@ -81,7 +82,7 @@ pub use case::{case, CaseExpr}; pub use cast::{cast, cast_column, cast_with_options, CastExpr}; pub use column::{col, Column, UnKnownColumn}; pub use datetime::{date_time_interval_expr, DateTimeIntervalExpr}; -pub use get_indexed_field::GetIndexedFieldExpr; +pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 584d1d66955d..497fb42fe4df 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -28,7 +28,7 @@ use arrow::{ use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::ColumnarValue; /// A place holder expression, can not be evaluated. 
@@ -65,9 +65,7 @@ impl PhysicalExpr for NoOp { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - Err(DataFusionError::Plan( - "NoOp::evaluate() should not be called".to_owned(), - )) + plan_err!("NoOp::evaluate() should not be called") } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c9683d2cdbc9..df76d55bfcaa 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -380,6 +380,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln), BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10), BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2), + BuiltinScalarFunction::Nanvl => { + Arc::new(|args| make_scalar_function(math_expressions::nanvl)(args)) + } BuiltinScalarFunction::Radians => Arc::new(math_expressions::to_radians), BuiltinScalarFunction::Random => Arc::new(math_expressions::random), BuiltinScalarFunction::Round => { @@ -428,7 +431,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayDims => { Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args)) } - BuiltinScalarFunction::ArrayFill => Arc::new(array_expressions::array_fill), + BuiltinScalarFunction::ArrayElement => { + Arc::new(|args| make_scalar_function(array_expressions::array_element)(args)) + } BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } @@ -444,6 +449,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayPrepend => { Arc::new(|args| make_scalar_function(array_expressions::array_prepend)(args)) } + BuiltinScalarFunction::ArrayRepeat => { + Arc::new(|args| make_scalar_function(array_expressions::array_repeat)(args)) + } BuiltinScalarFunction::ArrayRemove => { Arc::new(|args| make_scalar_function(array_expressions::array_remove)(args)) } @@ -462,6 +470,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { make_scalar_function(array_expressions::array_replace_all)(args) }), + BuiltinScalarFunction::ArraySlice => { + Arc::new(|args| make_scalar_function(array_expressions::array_slice)(args)) + } BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), @@ -471,12 +482,11 @@ pub fn create_physical_fun( BuiltinScalarFunction::MakeArray => { Arc::new(|args| make_scalar_function(array_expressions::make_array)(args)) } - BuiltinScalarFunction::TrimArray => { - Arc::new(|args| make_scalar_function(array_expressions::trim_array)(args)) - } - // string functions + // struct functions BuiltinScalarFunction::Struct => Arc::new(struct_expressions::struct_expr), + + // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::ascii::)(args) @@ -894,6 +904,7 @@ mod tests { record_batch::RecordBatch, }; use datafusion_common::cast::as_uint64_array; + use datafusion_common::plan_err; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; @@ -2843,16 +2854,16 @@ mod tests { match expr { Ok(..) 
=> { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Builtin scalar function {fun} does not support empty arguments" - ))); + ); } Err(DataFusionError::Plan(err)) => { if !err .contains("No function matches the given name and argument types") { - return Err(DataFusionError::Internal(format!( - "Builtin scalar function {fun} didn't got the right error message with empty arguments"))); + return plan_err!( + "Builtin scalar function {fun} didn't got the right error message with empty arguments"); } } Err(..) => { diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 883c016c047b..03e0bb64551b 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -259,6 +259,53 @@ pub fn lcm(args: &[ArrayRef]) -> Result { } } +/// Nanvl SQL function +pub fn nanvl(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => { + let compute_nanvl = |x: f64, y: f64| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float64Array, + { compute_nanvl } + )) as ArrayRef) + } + + DataType::Float32 => { + let compute_nanvl = |x: f32, y: f32| { + if x.is_nan() { + y + } else { + x + } + }; + + Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "x", + "y", + Float32Array, + { compute_nanvl } + )) as ArrayRef) + } + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function nanvl" + ))), + } +} + /// Pi SQL function pub fn pi(args: &[ColumnarValue]) -> Result { if !matches!(&args[0], ColumnarValue::Array(_)) { @@ -958,4 +1005,40 @@ mod tests { assert_eq!(floats.value(3), 123.0); assert_eq!(floats.value(4), -321.0); } + + #[test] + fn test_nanvl_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])), // y + Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function atan2"); + let floats = + as_float64_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } + + #[test] + fn test_nanvl_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])), // y + Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])), // x + ]; + + let result = nanvl(&args).expect("failed to initialize function atan2"); + let floats = + as_float32_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), 1.0); + assert_eq!(floats.value(1), 6.0); + assert_eq!(floats.value(2), 3.0); + assert!(floats.value(3).is_nan()); + } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 72a5ccef3463..f2211701fa1a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -19,17 +19,20 @@ use crate::var_provider::is_system_variables; use crate::{ execution_props::ExecutionProps, expressions::{ - self, binary, date_time_interval_expr, like, Column, GetIndexedFieldExpr, Literal, + self, binary, date_time_interval_expr, like, Column, GetFieldAccessExpr, + GetIndexedFieldExpr, Literal, }, functions, udf, var_provider::VarType, PhysicalExpr, }; use 
arrow::datatypes::{DataType, Schema}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF}; use datafusion_expr::{ - binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast, + binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, + Operator, TryCast, }; use std::sync::Arc; @@ -75,9 +78,7 @@ pub fn create_physical_expr( let scalar_value = provider.get_value(variable_names.clone())?; Ok(Arc::new(Literal::new(scalar_value))) } - _ => Err(DataFusionError::Plan( - "No system variable provider found".to_string(), - )), + _ => plan_err!("No system variable provider found"), } } else { match execution_props.get_var_provider(VarType::UserDefined) { @@ -85,9 +86,7 @@ pub fn create_physical_expr( let scalar_value = provider.get_value(variable_names.clone())?; Ok(Arc::new(Literal::new(scalar_value))) } - _ => Err(DataFusionError::Plan( - "No user defined variable provider found".to_string(), - )), + _ => plan_err!("No user defined variable provider found"), } } } @@ -341,7 +340,36 @@ pub fn create_physical_expr( input_schema, execution_props, )?), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + GetFieldAccessExpr::NamedStructField { name: name.clone() } + } + GetFieldAccess::ListIndex { key } => GetFieldAccessExpr::ListIndex { + key: create_physical_expr( + key, + input_dfschema, + input_schema, + execution_props, + )?, + }, + GetFieldAccess::ListRange { start, stop } => { + GetFieldAccessExpr::ListRange { + start: create_physical_expr( + start, + input_dfschema, + input_schema, + execution_props, + )?, + stop: create_physical_expr( + stop, + input_dfschema, + input_schema, + execution_props, + )?, + } + } + }; Ok(Arc::new(GetIndexedFieldExpr::new( create_physical_expr( expr, @@ -349,7 +377,7 @@ pub fn create_physical_expr( input_schema, execution_props, )?, - key.clone(), + field, ))) } diff --git a/datafusion/physical-expr/src/regex_expressions.rs b/datafusion/physical-expr/src/regex_expressions.rs index 3965897093ac..419eafb1c8d3 100644 --- a/datafusion/physical-expr/src/regex_expressions.rs +++ b/datafusion/physical-expr/src/regex_expressions.rs @@ -26,6 +26,7 @@ use arrow::array::{ OffsetSizeTrait, }; use arrow::compute; +use datafusion_common::plan_err; use datafusion_common::{cast::as_generic_string_array, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use hashbrown::HashMap; @@ -65,7 +66,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { match flags { Some(f) if f.iter().any(|s| s == Some("g")) => { - Err(DataFusionError::Plan("regexp_match() does not support the \"global\" option".to_owned())) + plan_err!("regexp_match() does not support the \"global\" option") }, _ => compute::regexp_match(values, regex, flags).map_err(DataFusionError::ArrowError), } diff --git a/datafusion/physical-expr/src/sort_expr.rs b/datafusion/physical-expr/src/sort_expr.rs index 0ba4627bc7d4..83d32dfeec17 100644 --- a/datafusion/physical-expr/src/sort_expr.rs +++ b/datafusion/physical-expr/src/sort_expr.rs @@ -24,6 +24,7 @@ use crate::PhysicalExpr; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::record_batch::RecordBatch; +use datafusion_common::plan_err; use datafusion_common::{DataFusionError, Result}; use 
datafusion_expr::ColumnarValue; @@ -64,9 +65,9 @@ impl PhysicalSortExpr { let array_to_sort = match value_to_sort { ColumnarValue::Array(array) => array, ColumnarValue::Scalar(scalar) => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Sort operation is not applicable to scalar value {scalar}" - ))); + ); } }; Ok(SortColumn { diff --git a/datafusion/physical-expr/src/struct_expressions.rs b/datafusion/physical-expr/src/struct_expressions.rs index d4cfa309ff44..e2dd98cdf1ba 100644 --- a/datafusion/physical-expr/src/struct_expressions.rs +++ b/datafusion/physical-expr/src/struct_expressions.rs @@ -78,3 +78,55 @@ pub fn struct_expr(values: &[ColumnarValue]) -> Result { .collect(); Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_struct_array; + use datafusion_common::ScalarValue; + + #[test] + fn test_struct() { + // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} + let args = [ + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ]; + let struc = struct_expr(&args) + .expect("failed to initialize function struct") + .into_array(1); + let result = + as_struct_array(&struc).expect("failed to initialize function struct"); + assert_eq!( + &Int64Array::from(vec![1]), + result + .column_by_name("c0") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![2]), + result + .column_by_name("c1") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + assert_eq!( + &Int64Array::from(vec![3]), + result + .column_by_name("c2") + .unwrap() + .clone() + .as_any() + .downcast_ref::() + .unwrap() + ); + } +} diff --git a/datafusion/proto/.gitignore b/datafusion/proto/.gitignore new file mode 100644 index 000000000000..3aa373dc479b --- /dev/null +++ b/datafusion/proto/.gitignore @@ -0,0 +1,4 @@ +# Files generated by regen.sh +proto/proto_descriptor.bin +src/datafusion.rs +datafusion.serde.rs diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e9ae76b25d18..2254a8cd3f30 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -400,11 +400,26 @@ message RollupNode { repeated LogicalExprNode expr = 1; } +message NamedStructField { + ScalarValue name = 1; +} + +message ListIndex { + LogicalExprNode key = 1; +} +message ListRange { + LogicalExprNode start = 1; + LogicalExprNode stop = 2; +} message GetIndexedField { LogicalExprNode expr = 1; - ScalarValue key = 2; + oneof field { + NamedStructField named_struct_field = 2; + ListIndex list_index = 3; + ListRange list_range = 4; + } } message IsNull { @@ -556,7 +571,7 @@ enum ScalarFunction { ArrayAppend = 86; ArrayConcat = 87; ArrayDims = 88; - ArrayFill = 89; + ArrayRepeat = 89; ArrayLength = 90; ArrayNdims = 91; ArrayPosition = 92; @@ -566,7 +581,8 @@ enum ScalarFunction { ArrayReplace = 96; ArrayToString = 97; Cardinality = 98; - TrimArray = 99; + ArrayElement = 99; + ArraySlice = 100; Encode = 101; Decode = 102; Cot = 103; @@ -577,6 +593,7 @@ enum ScalarFunction { ArrayReplaceN = 108; ArrayRemoveAll = 109; ArrayReplaceAll = 110; + Nanvl = 111; } message ScalarFunctionNode { @@ -613,6 +630,15 @@ enum AggregateFunction { // we append "_AGG" to obey name scoping rules. 
FIRST_VALUE_AGG = 24; LAST_VALUE_AGG = 25; + REGR_SLOPE = 26; + REGR_INTERCEPT = 27; + REGR_COUNT = 28; + REGR_R2 = 29; + REGR_AVGX = 30; + REGR_AVGY = 31; + REGR_SXX = 32; + REGR_SYY = 33; + REGR_SXY = 34; } message AggregateExprNode { @@ -1485,7 +1511,24 @@ message ColumnStats { uint32 distinct_count = 4; } +message NamedStructFieldExpr { + ScalarValue name = 1; +} + +message ListIndexExpr { + PhysicalExprNode key = 1; +} + +message ListRangeExpr { + PhysicalExprNode start = 1; + PhysicalExprNode stop = 2; +} + message PhysicalGetIndexedFieldExprNode { PhysicalExprNode arg = 1; - ScalarValue key = 2; + oneof field { + NamedStructFieldExpr named_struct_field_expr = 2; + ListIndexExpr list_index_expr = 3; + ListRangeExpr list_range_expr = 4; + } } diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 566dffb5350a..4bd741771909 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -455,7 +455,7 @@ mod test { #[test] fn roundtrip_deeply_nested() { // we need more stack space so this doesn't overflow in dev builds - std::thread::Builder::new().stack_size(10_000_000).spawn(|| { + std::thread::Builder::new().stack_size(20_000_000).spawn(|| { // don't know what "too much" is, so let's slowly try to increase complexity let n_max = 100; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 9d2cb06f6314..024bb949baa9 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -18,6 +18,7 @@ use std::{collections::HashSet, sync::Arc}; use datafusion::execution::registry::FunctionRegistry; +use datafusion_common::plan_err; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; @@ -31,20 +32,14 @@ impl FunctionRegistry for NoRegistry { } fn udf(&self, name: &str) -> Result> { - Err(DataFusionError::Plan( - format!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'")) - ) + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Function '{name}'") } fn udaf(&self, name: &str) -> Result> { - Err(DataFusionError::Plan( - format!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'")) - ) + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Aggregate Function '{name}'") } fn udwf(&self, name: &str) -> Result> { - Err(DataFusionError::Plan( - format!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'")) - ) + plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{name}'") } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a5d85cc6cf14..d65e45d51d42 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -465,6 +465,15 @@ impl serde::Serialize for AggregateFunction { Self::BoolOr => "BOOL_OR", Self::FirstValueAgg => "FIRST_VALUE_AGG", Self::LastValueAgg => "LAST_VALUE_AGG", + Self::RegrSlope => "REGR_SLOPE", + Self::RegrIntercept => "REGR_INTERCEPT", + Self::RegrCount => "REGR_COUNT", + Self::RegrR2 => "REGR_R2", + Self::RegrAvgx => "REGR_AVGX", + Self::RegrAvgy => "REGR_AVGY", + Self::RegrSxx => "REGR_SXX", + Self::RegrSyy => "REGR_SYY", + Self::RegrSxy => "REGR_SXY", }; 
serializer.serialize_str(variant) } @@ -502,6 +511,15 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BOOL_OR", "FIRST_VALUE_AGG", "LAST_VALUE_AGG", + "REGR_SLOPE", + "REGR_INTERCEPT", + "REGR_COUNT", + "REGR_R2", + "REGR_AVGX", + "REGR_AVGY", + "REGR_SXX", + "REGR_SYY", + "REGR_SXY", ]; struct GeneratedVisitor; @@ -570,6 +588,15 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BOOL_OR" => Ok(AggregateFunction::BoolOr), "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg), "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg), + "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), + "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), + "REGR_COUNT" => Ok(AggregateFunction::RegrCount), + "REGR_R2" => Ok(AggregateFunction::RegrR2), + "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), + "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), + "REGR_SXX" => Ok(AggregateFunction::RegrSxx), + "REGR_SYY" => Ok(AggregateFunction::RegrSyy), + "REGR_SXY" => Ok(AggregateFunction::RegrSxy), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -7270,15 +7297,25 @@ impl serde::Serialize for GetIndexedField { if self.expr.is_some() { len += 1; } - if self.key.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + if let Some(v) = self.field.as_ref() { + match v { + get_indexed_field::Field::NamedStructField(v) => { + struct_ser.serialize_field("namedStructField", v)?; + } + get_indexed_field::Field::ListIndex(v) => { + struct_ser.serialize_field("listIndex", v)?; + } + get_indexed_field::Field::ListRange(v) => { + struct_ser.serialize_field("listRange", v)?; + } + } } struct_ser.end() } @@ -7291,13 +7328,20 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { { const FIELDS: &[&str] = &[ "expr", - "key", + "named_struct_field", + "namedStructField", + "list_index", + "listIndex", + "list_range", + "listRange", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - Key, + NamedStructField, + ListIndex, + ListRange, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7320,7 +7364,9 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { { match value { "expr" => Ok(GeneratedField::Expr), - "key" => Ok(GeneratedField::Key), + "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), + "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), + "listRange" | "list_range" => Ok(GeneratedField::ListRange), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7341,7 +7387,7 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut key__ = None; + let mut field__ = None; while let Some(k) = map.next_key()? 
{ match k { GeneratedField::Expr => { @@ -7350,17 +7396,32 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { } expr__ = map.next_value()?; } - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::NamedStructField => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructField")); } - key__ = map.next_value()?; + field__ = map.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) +; + } + GeneratedField::ListIndex => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndex")); + } + field__ = map.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) +; + } + GeneratedField::ListRange => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRange")); + } + field__ = map.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) +; } } } Ok(GetIndexedField { expr: expr__, - key: key__, + field: field__, }) } } @@ -10030,41 +10091,423 @@ impl<'de> serde::Deserialize<'de> for LimitNode { let mut fetch__ = None; while let Some(k) = map.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map.next_value()?; - } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map.next_value()?; + } + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); + } + skip__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(LimitNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for List { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for List { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: 
serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.List") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + while let Some(k) = map.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map.next_value()?; + } + } + } + Ok(List { + field_type: field_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndex { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndex { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndex; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndex") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map.next_value()?; + } + } + } + Ok(ListIndex { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListIndexExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndexExpr", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListIndexExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListIndexExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListIndexExpr") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map.next_value()?; + } + } + } + Ok(ListIndexExpr { + key: key__, + }) + } + } + deserializer.deserialize_struct("datafusion.ListIndexExpr", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ListRange { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.start.is_some() { + len += 1; + } + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ListRange { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "start", + "stop", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Start, + Stop, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ListRange; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ListRange") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut start__ = None; + let mut stop__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); } - skip__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + start__ = map.next_value()?; } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); } - fetch__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + stop__ = map.next_value()?; } } } - Ok(LimitNode { - input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + Ok(ListRange { + start: start__, + stop: stop__, }) } } - deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for List { +impl serde::Serialize for ListRangeExpr { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10072,30 +10515,37 @@ impl serde::Serialize for List { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { + if self.start.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; + if self.stop.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListRangeExpr", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for List { +impl<'de> serde::Deserialize<'de> for ListRangeExpr { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", + "start", + "stop", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, + Start, + Stop, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10117,7 +10567,8 @@ impl<'de> serde::Deserialize<'de> for List { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10127,33 +10578,41 @@ impl<'de> serde::Deserialize<'de> for List { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = List; + type Value = ListRangeExpr; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.List") + formatter.write_str("struct datafusion.ListRangeExpr") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; + let mut start__ = None; + let mut stop__ = None; while let Some(k) = map.next_key()? 
{ match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); } - field_type__ = map.next_value()?; + start__ = map.next_value()?; + } + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); + } + stop__ = map.next_value()?; } } } - Ok(List { - field_type: field_type__, + Ok(ListRangeExpr { + start: start__, + stop: stop__, }) } } - deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListRangeExpr", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ListingTableScanNode { @@ -12129,6 +12588,188 @@ impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for NamedStructField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NamedStructField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NamedStructField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NamedStructField") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = map.next_value()?; + } + } + } + Ok(NamedStructField { + name: name__, + }) + } + } + deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for NamedStructFieldExpr { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructFieldExpr", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NamedStructFieldExpr; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.NamedStructFieldExpr") + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = map.next_value()?; + } + } + } + Ok(NamedStructFieldExpr { + name: name__, + }) + } + } + deserializer.deserialize_struct("datafusion.NamedStructFieldExpr", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -14775,15 +15416,25 @@ impl serde::Serialize for PhysicalGetIndexedFieldExprNode { if self.arg.is_some() { len += 1; } - if self.key.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", len)?; if let Some(v) = self.arg.as_ref() { struct_ser.serialize_field("arg", v)?; } - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + if let Some(v) = self.field.as_ref() { + match v { + physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(v) => { + struct_ser.serialize_field("namedStructFieldExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListIndexExpr(v) => { + struct_ser.serialize_field("listIndexExpr", v)?; + } + physical_get_indexed_field_expr_node::Field::ListRangeExpr(v) => { + struct_ser.serialize_field("listRangeExpr", v)?; + } + } } struct_ser.end() } @@ -14796,13 +15447,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { { const FIELDS: &[&str] = &[ "arg", - "key", + "named_struct_field_expr", + "namedStructFieldExpr", + "list_index_expr", + "listIndexExpr", + "list_range_expr", + "listRangeExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Arg, - Key, + NamedStructFieldExpr, + ListIndexExpr, + ListRangeExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14825,7 +15483,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { { match value { "arg" => Ok(GeneratedField::Arg), - "key" => Ok(GeneratedField::Key), + "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), + "listIndexExpr" | "list_index_expr" => Ok(GeneratedField::ListIndexExpr), + "listRangeExpr" | "list_range_expr" => Ok(GeneratedField::ListRangeExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14846,7 +15506,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { V: serde::de::MapAccess<'de>, { let mut arg__ = None; - let mut key__ = None; + let mut field__ = None; while let Some(k) = map.next_key()? 
{ match k { GeneratedField::Arg => { @@ -14855,17 +15515,32 @@ impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { } arg__ = map.next_value()?; } - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::NamedStructFieldExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("namedStructFieldExpr")); } - key__ = map.next_value()?; + field__ = map.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr) +; + } + GeneratedField::ListIndexExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listIndexExpr")); + } + field__ = map.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListIndexExpr) +; + } + GeneratedField::ListRangeExpr => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("listRangeExpr")); + } + field__ = map.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::ListRangeExpr) +; } } } Ok(PhysicalGetIndexedFieldExprNode { arg: arg__, - key: key__, + field: field__, }) } } @@ -18260,7 +18935,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayAppend => "ArrayAppend", Self::ArrayConcat => "ArrayConcat", Self::ArrayDims => "ArrayDims", - Self::ArrayFill => "ArrayFill", + Self::ArrayRepeat => "ArrayRepeat", Self::ArrayLength => "ArrayLength", Self::ArrayNdims => "ArrayNdims", Self::ArrayPosition => "ArrayPosition", @@ -18270,7 +18945,8 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplace => "ArrayReplace", Self::ArrayToString => "ArrayToString", Self::Cardinality => "Cardinality", - Self::TrimArray => "TrimArray", + Self::ArrayElement => "ArrayElement", + Self::ArraySlice => "ArraySlice", Self::Encode => "Encode", Self::Decode => "Decode", Self::Cot => "Cot", @@ -18281,6 +18957,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplaceN => "ArrayReplaceN", Self::ArrayRemoveAll => "ArrayRemoveAll", Self::ArrayReplaceAll => "ArrayReplaceAll", + Self::Nanvl => "Nanvl", }; serializer.serialize_str(variant) } @@ -18381,7 +19058,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend", "ArrayConcat", "ArrayDims", - "ArrayFill", + "ArrayRepeat", "ArrayLength", "ArrayNdims", "ArrayPosition", @@ -18391,7 +19068,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace", "ArrayToString", "Cardinality", - "TrimArray", + "ArrayElement", + "ArraySlice", "Encode", "Decode", "Cot", @@ -18402,6 +19080,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceN", "ArrayRemoveAll", "ArrayReplaceAll", + "Nanvl", ]; struct GeneratedVisitor; @@ -18533,7 +19212,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), "ArrayDims" => Ok(ScalarFunction::ArrayDims), - "ArrayFill" => Ok(ScalarFunction::ArrayFill), + "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), "ArrayLength" => Ok(ScalarFunction::ArrayLength), "ArrayNdims" => Ok(ScalarFunction::ArrayNdims), "ArrayPosition" => Ok(ScalarFunction::ArrayPosition), @@ -18543,7 +19222,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "ArrayToString" => Ok(ScalarFunction::ArrayToString), "Cardinality" => Ok(ScalarFunction::Cardinality), - "TrimArray" => Ok(ScalarFunction::TrimArray), + "ArrayElement" => 
Ok(ScalarFunction::ArrayElement), + "ArraySlice" => Ok(ScalarFunction::ArraySlice), "Encode" => Ok(ScalarFunction::Encode), "Decode" => Ok(ScalarFunction::Decode), "Cot" => Ok(ScalarFunction::Cot), @@ -18554,6 +19234,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), + "Nanvl" => Ok(ScalarFunction::Nanvl), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c6f3a23ed65f..867853b128fc 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -611,11 +611,44 @@ pub struct RollupNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NamedStructField { + #[prost(message, optional, tag = "1")] + pub name: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListIndex { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListRange { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct GetIndexedField { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] + pub field: ::core::option::Option, +} +/// Nested message and enum types in `GetIndexedField`. 
+pub mod get_indexed_field { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Field { + #[prost(message, tag = "2")] + NamedStructField(super::NamedStructField), + #[prost(message, tag = "3")] + ListIndex(::prost::alloc::boxed::Box), + #[prost(message, tag = "4")] + ListRange(::prost::alloc::boxed::Box), + } } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -2121,11 +2154,44 @@ pub struct ColumnStats { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct NamedStructFieldExpr { + #[prost(message, optional, tag = "1")] + pub name: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListIndexExpr { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ListRangeExpr { + #[prost(message, optional, boxed, tag = "1")] + pub start: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalGetIndexedFieldExprNode { #[prost(message, optional, boxed, tag = "1")] pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub key: ::core::option::Option, + #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2, 3, 4")] + pub field: ::core::option::Option, +} +/// Nested message and enum types in `PhysicalGetIndexedFieldExprNode`. +pub mod physical_get_indexed_field_expr_node { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Field { + #[prost(message, tag = "2")] + NamedStructFieldExpr(super::NamedStructFieldExpr), + #[prost(message, tag = "3")] + ListIndexExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "4")] + ListRangeExpr(::prost::alloc::boxed::Box), + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] @@ -2289,7 +2355,7 @@ pub enum ScalarFunction { ArrayAppend = 86, ArrayConcat = 87, ArrayDims = 88, - ArrayFill = 89, + ArrayRepeat = 89, ArrayLength = 90, ArrayNdims = 91, ArrayPosition = 92, @@ -2299,7 +2365,8 @@ pub enum ScalarFunction { ArrayReplace = 96, ArrayToString = 97, Cardinality = 98, - TrimArray = 99, + ArrayElement = 99, + ArraySlice = 100, Encode = 101, Decode = 102, Cot = 103, @@ -2310,6 +2377,7 @@ pub enum ScalarFunction { ArrayReplaceN = 108, ArrayRemoveAll = 109, ArrayReplaceAll = 110, + Nanvl = 111, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. 
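// Illustrative sketch (not part of the generated code in this patch): the hunks
// above replace the old scalar `key` on `GetIndexedField` with a three-variant
// `field` oneof. A minimal example of constructing the new message, assuming the
// usual `datafusion_proto::protobuf` re-export of these generated types; the inner
// expression values are placeholders.
use datafusion_proto::protobuf::{get_indexed_field, GetIndexedField, ListIndex, LogicalExprNode};

fn list_index_access(expr: LogicalExprNode, key: LogicalExprNode) -> GetIndexedField {
    GetIndexedField {
        // tag 1: the expression being indexed
        expr: Some(Box::new(expr)),
        // tags 2-4 form the oneof: NamedStructField, ListIndex, ListRange.
        // Here we pick `arr[key]`-style access (tag 3).
        field: Some(get_indexed_field::Field::ListIndex(Box::new(ListIndex {
            key: Some(Box::new(key)),
        }))),
    }
}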
@@ -2407,7 +2475,7 @@ impl ScalarFunction { ScalarFunction::ArrayAppend => "ArrayAppend", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", - ScalarFunction::ArrayFill => "ArrayFill", + ScalarFunction::ArrayRepeat => "ArrayRepeat", ScalarFunction::ArrayLength => "ArrayLength", ScalarFunction::ArrayNdims => "ArrayNdims", ScalarFunction::ArrayPosition => "ArrayPosition", @@ -2417,7 +2485,8 @@ impl ScalarFunction { ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::ArrayToString => "ArrayToString", ScalarFunction::Cardinality => "Cardinality", - ScalarFunction::TrimArray => "TrimArray", + ScalarFunction::ArrayElement => "ArrayElement", + ScalarFunction::ArraySlice => "ArraySlice", ScalarFunction::Encode => "Encode", ScalarFunction::Decode => "Decode", ScalarFunction::Cot => "Cot", @@ -2428,6 +2497,7 @@ impl ScalarFunction { ScalarFunction::ArrayReplaceN => "ArrayReplaceN", ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", + ScalarFunction::Nanvl => "Nanvl", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2522,7 +2592,7 @@ impl ScalarFunction { "ArrayAppend" => Some(Self::ArrayAppend), "ArrayConcat" => Some(Self::ArrayConcat), "ArrayDims" => Some(Self::ArrayDims), - "ArrayFill" => Some(Self::ArrayFill), + "ArrayRepeat" => Some(Self::ArrayRepeat), "ArrayLength" => Some(Self::ArrayLength), "ArrayNdims" => Some(Self::ArrayNdims), "ArrayPosition" => Some(Self::ArrayPosition), @@ -2532,7 +2602,8 @@ impl ScalarFunction { "ArrayReplace" => Some(Self::ArrayReplace), "ArrayToString" => Some(Self::ArrayToString), "Cardinality" => Some(Self::Cardinality), - "TrimArray" => Some(Self::TrimArray), + "ArrayElement" => Some(Self::ArrayElement), + "ArraySlice" => Some(Self::ArraySlice), "Encode" => Some(Self::Encode), "Decode" => Some(Self::Decode), "Cot" => Some(Self::Cot), @@ -2543,6 +2614,7 @@ impl ScalarFunction { "ArrayReplaceN" => Some(Self::ArrayReplaceN), "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), + "Nanvl" => Some(Self::Nanvl), _ => None, } } @@ -2578,6 +2650,15 @@ pub enum AggregateFunction { /// we append "_AGG" to obey name scoping rules. FirstValueAgg = 24, LastValueAgg = 25, + RegrSlope = 26, + RegrIntercept = 27, + RegrCount = 28, + RegrR2 = 29, + RegrAvgx = 30, + RegrAvgy = 31, + RegrSxx = 32, + RegrSyy = 33, + RegrSxy = 34, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2614,6 +2695,15 @@ impl AggregateFunction { AggregateFunction::BoolOr => "BOOL_OR", AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG", AggregateFunction::LastValueAgg => "LAST_VALUE_AGG", + AggregateFunction::RegrSlope => "REGR_SLOPE", + AggregateFunction::RegrIntercept => "REGR_INTERCEPT", + AggregateFunction::RegrCount => "REGR_COUNT", + AggregateFunction::RegrR2 => "REGR_R2", + AggregateFunction::RegrAvgx => "REGR_AVGX", + AggregateFunction::RegrAvgy => "REGR_AVGY", + AggregateFunction::RegrSxx => "REGR_SXX", + AggregateFunction::RegrSyy => "REGR_SYY", + AggregateFunction::RegrSxy => "REGR_SXY", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2647,6 +2737,15 @@ impl AggregateFunction { "BOOL_OR" => Some(Self::BoolOr), "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg), "LAST_VALUE_AGG" => Some(Self::LastValueAgg), + "REGR_SLOPE" => Some(Self::RegrSlope), + "REGR_INTERCEPT" => Some(Self::RegrIntercept), + "REGR_COUNT" => Some(Self::RegrCount), + "REGR_R2" => Some(Self::RegrR2), + "REGR_AVGX" => Some(Self::RegrAvgx), + "REGR_AVGY" => Some(Self::RegrAvgy), + "REGR_SXX" => Some(Self::RegrSxx), + "REGR_SYY" => Some(Self::RegrSyy), + "REGR_SXY" => Some(Self::RegrSxy), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1464f32bb35d..c17d8dbd8ca9 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -36,25 +36,25 @@ use datafusion_common::{ }; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::{ - abs, acos, acosh, array, array_append, array_concat, array_dims, array_fill, + abs, acos, acosh, array, array_append, array_concat, array_dims, array_element, array_has, array_has_all, array_has_any, array_length, array_ndims, array_position, array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, - array_replace, array_replace_all, array_replace_n, array_to_string, ascii, asin, - asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, - character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, exp, + array_repeat, array_replace, array_replace_all, array_replace_n, array_slice, + array_to_string, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, + cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, + concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, + date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, md5, now, nullif, octet_length, pi, power, radians, random, - regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, - sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, starts_with, - strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, - to_timestamp_millis, to_timestamp_seconds, translate, trim, trim_array, trunc, upper, - uuid, + lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, + rtrim, sha224, sha256, sha384, sha512, signum, sin, sinh, split_part, sqrt, + starts_with, strpos, substr, substring, tan, tanh, to_hex, to_timestamp_micros, + to_timestamp_millis, to_timestamp_seconds, translate, trim, trunc, upper, uuid, window_frame::regularize, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetIndexedField, GroupingSet, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -456,22 +456,23 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHasAny => Self::ArrayHasAny, ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, - ScalarFunction::ArrayFill => Self::ArrayFill, + 
ScalarFunction::ArrayElement => Self::ArrayElement, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, ScalarFunction::ArrayPosition => Self::ArrayPosition, ScalarFunction::ArrayPositions => Self::ArrayPositions, ScalarFunction::ArrayPrepend => Self::ArrayPrepend, + ScalarFunction::ArrayRepeat => Self::ArrayRepeat, ScalarFunction::ArrayRemove => Self::ArrayRemove, ScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, ScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, ScalarFunction::ArrayReplace => Self::ArrayReplace, ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, - ScalarFunction::TrimArray => Self::TrimArray, ScalarFunction::NullIf => Self::NullIf, ScalarFunction::DatePart => Self::DatePart, ScalarFunction::DateTrunc => Self::DateTrunc, @@ -522,6 +523,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::StructFun => Self::Struct, ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, + ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, } } @@ -549,6 +551,15 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, protobuf::AggregateFunction::Correlation => Self::Correlation, + protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, + protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, + protobuf::AggregateFunction::RegrCount => Self::RegrCount, + protobuf::AggregateFunction::RegrR2 => Self::RegrR2, + protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, + protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, + protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, + protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, + protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, protobuf::AggregateFunction::ApproxPercentileCont => { Self::ApproxPercentileCont } @@ -929,18 +940,48 @@ pub fn parse_expr( }) .expect("Binary expression could not be reduced to a single expression.")) } - ExprType::GetIndexedField(field) => { - let key = field - .key - .as_ref() - .ok_or_else(|| Error::required("value"))? - .try_into()?; - - let expr = parse_required_expr(field.expr.as_deref(), registry, "expr")?; + ExprType::GetIndexedField(get_indexed_field) => { + let expr = + parse_required_expr(get_indexed_field.expr.as_deref(), registry, "expr")?; + let field = match &get_indexed_field.field { + Some(protobuf::get_indexed_field::Field::NamedStructField( + named_struct_field, + )) => GetFieldAccess::NamedStructField { + name: named_struct_field + .name + .as_ref() + .ok_or_else(|| Error::required("value"))? 
+ .try_into()?, + }, + Some(protobuf::get_indexed_field::Field::ListIndex(list_index)) => { + GetFieldAccess::ListIndex { + key: Box::new(parse_required_expr( + list_index.key.as_deref(), + registry, + "key", + )?), + } + } + Some(protobuf::get_indexed_field::Field::ListRange(list_range)) => { + GetFieldAccess::ListRange { + start: Box::new(parse_required_expr( + list_range.start.as_deref(), + registry, + "start", + )?), + stop: Box::new(parse_required_expr( + list_range.stop.as_deref(), + registry, + "stop", + )?), + } + } + None => return Err(proto_error("Field must not be None")), + }; Ok(Expr::GetIndexedField(GetIndexedField::new( Box::new(expr), - key, + field, ))) } ExprType::Column(column) => Ok(Expr::Column(column.into())), @@ -1246,10 +1287,6 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), - ScalarFunction::ArrayFill => Ok(array_fill( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::ArrayPosition => Ok(array_position( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1259,6 +1296,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArrayRepeat => Ok(array_repeat( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayRemove => Ok(array_remove( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1288,6 +1329,11 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, parse_expr(&args[2], registry)?, )), + ScalarFunction::ArraySlice => Ok(array_slice( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, + )), ScalarFunction::ArrayToString => Ok(array_to_string( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1295,10 +1341,6 @@ pub fn parse_expr( ScalarFunction::Cardinality => { Ok(cardinality(parse_expr(&args[0], registry)?)) } - ScalarFunction::TrimArray => Ok(trim_array( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::ArrayLength => Ok(array_length( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1306,6 +1348,10 @@ pub fn parse_expr( ScalarFunction::ArrayDims => { Ok(array_dims(parse_expr(&args[0], registry)?)) } + ScalarFunction::ArrayElement => Ok(array_element( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::ArrayNdims => { Ok(array_ndims(parse_expr(&args[0], registry)?)) } @@ -1526,6 +1572,10 @@ pub fn parse_expr( ScalarFunction::CurrentDate => Ok(current_date()), ScalarFunction::CurrentTime => Ok(current_time()), ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry)?)), + ScalarFunction::Nanvl => Ok(nanvl( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 9cbce29ed61a..d00e5e2f5908 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1462,7 +1462,9 @@ mod roundtrip_tests { create_udf, CsvReadOptions, SessionConfig, SessionContext, }; use datafusion::test_util::{TestTableFactory, TestTableProvider}; - use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; + use datafusion_common::{ + plan_err, DFSchemaRef, DataFusionError, Result, ScalarValue, + }; use datafusion_expr::expr::{ self, 
Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, ScalarUDF, Sort, @@ -2915,11 +2917,11 @@ mod roundtrip_tests { fn return_type(arg_types: &[DataType]) -> Result> { if arg_types.len() != 1 { - return Err(DataFusionError::Plan(format!( + return plan_err!( "dummy_udwf expects 1 argument, got {}: {:?}", arg_types.len(), arg_types - ))); + ); } Ok(Arc::new(arg_types[0].clone())) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index df5701a282c3..aa1132e8b1f6 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -36,8 +36,8 @@ use arrow::datatypes::{ }; use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference, ScalarValue}; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, InList, Like, - Placeholder, ScalarFunction, ScalarUDF, Sort, + self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, + InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -384,6 +384,15 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, + AggregateFunction::RegrSlope => Self::RegrSlope, + AggregateFunction::RegrIntercept => Self::RegrIntercept, + AggregateFunction::RegrCount => Self::RegrCount, + AggregateFunction::RegrR2 => Self::RegrR2, + AggregateFunction::RegrAvgx => Self::RegrAvgx, + AggregateFunction::RegrAvgy => Self::RegrAvgy, + AggregateFunction::RegrSXX => Self::RegrSxx, + AggregateFunction::RegrSYY => Self::RegrSyy, + AggregateFunction::RegrSXY => Self::RegrSxy, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight @@ -675,6 +684,21 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, AggregateFunction::ApproxMedian => { protobuf::AggregateFunction::ApproxMedian } @@ -951,14 +975,41 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { // see discussion in https://github.com/apache/arrow-datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. 
} | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GetIndexedField(GetIndexedField { key, expr }) => Self { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - key: Some(key.try_into()?), - expr: Some(Box::new(expr.as_ref().try_into()?)), - }, - ))), - }, + Expr::GetIndexedField(GetIndexedField { expr, field }) => { + let field = match field { + GetFieldAccess::NamedStructField { name } => { + protobuf::get_indexed_field::Field::NamedStructField( + protobuf::NamedStructField { + name: Some(name.try_into()?), + }, + ) + } + GetFieldAccess::ListIndex { key } => { + protobuf::get_indexed_field::Field::ListIndex(Box::new( + protobuf::ListIndex { + key: Some(Box::new(key.as_ref().try_into()?)), + }, + )) + } + GetFieldAccess::ListRange { start, stop } => { + protobuf::get_indexed_field::Field::ListRange(Box::new( + protobuf::ListRange { + start: Some(Box::new(start.as_ref().try_into()?)), + stop: Some(Box::new(stop.as_ref().try_into()?)), + }, + )) + } + }; + + Self { + expr_type: Some(ExprType::GetIndexedField(Box::new( + protobuf::GetIndexedField { + expr: Some(Box::new(expr.as_ref().try_into()?)), + field: Some(field), + }, + ))), + } + } Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { expr_type: Some(ExprType::Cube(CubeNode { @@ -1404,22 +1455,23 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny, BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, - BuiltinScalarFunction::ArrayFill => Self::ArrayFill, + BuiltinScalarFunction::ArrayElement => Self::ArrayElement, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, BuiltinScalarFunction::ArrayPositions => Self::ArrayPositions, BuiltinScalarFunction::ArrayPrepend => Self::ArrayPrepend, + BuiltinScalarFunction::ArrayRepeat => Self::ArrayRepeat, BuiltinScalarFunction::ArrayRemove => Self::ArrayRemove, BuiltinScalarFunction::ArrayRemoveN => Self::ArrayRemoveN, BuiltinScalarFunction::ArrayRemoveAll => Self::ArrayRemoveAll, BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, + BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, - BuiltinScalarFunction::TrimArray => Self::TrimArray, BuiltinScalarFunction::NullIf => Self::NullIf, BuiltinScalarFunction::DatePart => Self::DatePart, BuiltinScalarFunction::DateTrunc => Self::DateTrunc, @@ -1470,6 +1522,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Struct => Self::StructFun, BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, + BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, }; diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 7a52e5f0d09f..e084188ccd04 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -30,7 +30,7 @@ use datafusion::execution::FunctionRegistry; use 
datafusion::logical_expr::window_function::WindowFunction; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - date_time_interval_expr, GetIndexedFieldExpr, + date_time_interval_expr, GetFieldAccessExpr, GetIndexedFieldExpr, }; use datafusion::physical_plan::expressions::{in_list, LikeExpr}; use datafusion::physical_plan::{ @@ -310,6 +310,36 @@ pub fn parse_physical_expr( )?, )), ExprType::GetIndexedFieldExpr(get_indexed_field_expr) => { + let field = match &get_indexed_field_expr.field { + Some(protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(named_struct_field_expr)) => GetFieldAccessExpr::NamedStructField{ + name: convert_required!(named_struct_field_expr.name)?, + }, + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListIndexExpr(list_index_expr)) => GetFieldAccessExpr::ListIndex{ + key: parse_required_physical_expr( + list_index_expr.key.as_deref(), + registry, + "key", + input_schema, + )?}, + Some(protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(list_range_expr)) => GetFieldAccessExpr::ListRange{ + start: parse_required_physical_expr( + list_range_expr.start.as_deref(), + registry, + "start", + input_schema, + )?, + stop: parse_required_physical_expr( + list_range_expr.stop.as_deref(), + registry, + "stop", + input_schema + )?, + }, + None => return Err(proto_error( + "Field must not be None", + )), + }; + Arc::new(GetIndexedFieldExpr::new( parse_required_physical_expr( get_indexed_field_expr.arg.as_deref(), @@ -317,7 +347,7 @@ pub fn parse_physical_expr( "arg", input_schema, )?, - convert_required!(get_indexed_field_expr.key)?, + field, )) } }; diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index e97a773d3472..fdb2ef88cb9e 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -1400,7 +1400,8 @@ mod roundtrip_tests { use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_plan::aggregates::PhysicalGroupBy; use datafusion::physical_plan::expressions::{ - date_time_interval_expr, like, BinaryExpr, GetIndexedFieldExpr, + date_time_interval_expr, like, BinaryExpr, GetFieldAccessExpr, + GetIndexedFieldExpr, }; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::projection::ProjectionExec; @@ -1408,7 +1409,7 @@ mod roundtrip_tests { use datafusion::{ arrow::{ compute::kernels::sort::SortOptions, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Fields, Schema}, }, datasource::{ listing::PartitionedFile, @@ -1919,18 +1920,83 @@ mod roundtrip_tests { } #[test] - fn roundtrip_get_indexed_field() -> Result<()> { + fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { let fields = vec![ Field::new("id", DataType::Int64, true), - Field::new_list("a", Field::new("item", DataType::Float64, true), true), + Field::new_struct( + "arg", + Fields::from(vec![Field::new("name", DataType::Float64, true)]), + true, + ), + ]; + + let schema = Schema::new(fields); + let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + + let col_arg = col("arg", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::NamedStructField { + name: ScalarValue::Utf8(Some(String::from("name"))), + }, + )); + + let plan = Arc::new(ProjectionExec::try_new( + vec![(get_indexed_field_expr, "result".to_string())], + input, + )?); + + 
roundtrip_test(plan) + } + + #[test] + fn roundtrip_get_indexed_field_list_index() -> Result<()> { + let fields = vec![ + Field::new("id", DataType::Int64, true), + Field::new_list("arg", Field::new("item", DataType::Float64, true), true), + Field::new("key", DataType::Int64, true), ]; let schema = Schema::new(fields); let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); - let col_a = col("a", &schema)?; - let key = ScalarValue::Int64(Some(1)); - let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new(col_a, key)); + let col_arg = col("arg", &schema)?; + let col_key = col("key", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::ListIndex { key: col_key }, + )); + + let plan = Arc::new(ProjectionExec::try_new( + vec![(get_indexed_field_expr, "result".to_string())], + input, + )?); + + roundtrip_test(plan) + } + + #[test] + fn roundtrip_get_indexed_field_list_range() -> Result<()> { + let fields = vec![ + Field::new("id", DataType::Int64, true), + Field::new_list("arg", Field::new("item", DataType::Float64, true), true), + Field::new("start", DataType::Int64, true), + Field::new("stop", DataType::Int64, true), + ]; + + let schema = Schema::new(fields); + let input = Arc::new(EmptyExec::new(false, Arc::new(schema.clone()))); + + let col_arg = col("arg", &schema)?; + let col_start = col("start", &schema)?; + let col_stop = col("stop", &schema)?; + let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( + col_arg, + GetFieldAccessExpr::ListRange { + start: col_start, + stop: col_stop, + }, + )); let plan = Arc::new(ProjectionExec::try_new( vec![(get_indexed_field_expr, "result".to_string())], diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index aaf3569d1634..45c38b5ad5e0 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -48,7 +48,9 @@ use crate::protobuf::{ ScalarValue, }; use datafusion::logical_expr::BuiltinScalarFunction; -use datafusion::physical_expr::expressions::{DateTimeIntervalExpr, GetIndexedFieldExpr}; +use datafusion::physical_expr::expressions::{ + DateTimeIntervalExpr, GetFieldAccessExpr, GetIndexedFieldExpr, +}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::joins::utils::JoinSide; use datafusion::physical_plan::udaf::AggregateFunctionExpr; @@ -389,12 +391,31 @@ impl TryFrom> for protobuf::PhysicalExprNode { )), }) } else if let Some(expr) = expr.downcast_ref::() { + let field = match expr.field() { + GetFieldAccessExpr::NamedStructField{name} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(protobuf::NamedStructFieldExpr { + name: Some(ScalarValue::try_from(name)?) 
+ }) + ), + GetFieldAccessExpr::ListIndex{key} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::ListIndexExpr(Box::new(protobuf::ListIndexExpr { + key: Some(Box::new(key.to_owned().try_into()?)) + })) + ), + GetFieldAccessExpr::ListRange{start, stop} => Some( + protobuf::physical_get_indexed_field_expr_node::Field::ListRangeExpr(Box::new(protobuf::ListRangeExpr { + start: Some(Box::new(start.to_owned().try_into()?)), + stop: Some(Box::new(stop.to_owned().try_into()?)), + })) + ), + }; + Ok(protobuf::PhysicalExprNode { expr_type: Some( protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( Box::new(protobuf::PhysicalGetIndexedFieldExprNode { arg: Some(Box::new(expr.arg().to_owned().try_into()?)), - key: Some(ScalarValue::try_from(expr.key())?), + field, }), ), ), diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 59b7921d672d..8a12cc32b641 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -17,7 +17,7 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, @@ -108,10 +108,7 @@ impl ContextProvider for MySchemaProvider { fn get_table_provider(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), - _ => Err(DataFusionError::Plan(format!( - "Table not found: {}", - name.table() - ))), + _ => plan_err!("Table not found: {}", name.table()), } } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 46957a9cdd86..549d46c5e277 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -23,6 +23,7 @@ use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::plan_err; use datafusion_expr::{Expr, ExprSchemable}; pub const ARROW_CAST_NAME: &str = "arrow_cast"; @@ -51,10 +52,7 @@ pub const ARROW_CAST_NAME: &str = "arrow_cast"; /// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { if args.len() != 2 { - return Err(DataFusionError::Plan(format!( - "arrow_cast needs 2 arguments, {} provided", - args.len() - ))); + return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } let arg1 = args.pop().unwrap(); let arg0 = args.pop().unwrap(); @@ -63,9 +61,9 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { v } else { - return Err(DataFusionError::Plan(format!( + return plan_err!( "arrow_cast requires its second argument to be a constant string, got {arg1}" - ))); + ); }; // do the actual lookup to the appropriate data type diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index f08f357ec42c..362a2ac42a83 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -16,6 +16,7 @@ // under the License. 
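// Illustrative sketch: the recurring `plan_err!` rewrites in these sql/* hunks are
// a mechanical shortening. Based on the before/after pairs in this patch,
// `plan_err!(...)` yields the same `Err(DataFusionError::Plan(format!(...)))` value
// the old code built by hand; the macro itself lives in `datafusion_common` and its
// exact definition is not shown here.
use datafusion_common::{plan_err, DataFusionError, Result};

fn require_two_args(n: usize) -> Result<()> {
    if n != 2 {
        // Pre-patch style:
        //     return Err(DataFusionError::Plan(format!("expected 2 arguments, got {n}")));
        return plan_err!("expected 2 arguments, got {n}");
    }
    Ok(())
}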
use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; use datafusion_expr::function::suggest_valid_function; @@ -65,9 +66,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // required ordering should be defined in OVER clause. let is_function_window = function.over.is_some(); if !function.order_by.is_empty() && is_function_window { - return Err(DataFusionError::Plan( - "Aggregate ORDER BY is not implemented for window functions".to_string(), - )); + return plan_err!( + "Aggregate ORDER BY is not implemented for window functions" + ); } // then, window function @@ -160,9 +161,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Could not find the relevant function, so return an error let suggested_func_name = suggest_valid_function(&name, is_function_window); - Err(DataFusionError::Plan(format!( - "Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?" - ))) + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } pub(super) fn sql_named_function_to_expr( diff --git a/datafusion/sql/src/expr/grouping_set.rs b/datafusion/sql/src/expr/grouping_set.rs index 34dfac158aa0..254f5079b7b1 100644 --- a/datafusion/sql/src/expr/grouping_set.rs +++ b/datafusion/sql/src/expr/grouping_set.rs @@ -16,6 +16,7 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::{Expr, GroupingSet}; use sqlparser::ast::Expr as SQLExpr; @@ -48,10 +49,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|v| { if v.len() != 1 { - Err(DataFusionError::Plan( + plan_err!( "Tuple expressions are not supported for Rollup expressions" - .to_string(), - )) + ) } else { self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) } @@ -70,10 +70,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|v| { if v.len() != 1 { - Err(DataFusionError::Plan( - "Tuple expressions not are supported for Cube expressions" - .to_string(), - )) + plan_err!("Tuple expressions not are supported for Cube expressions") } else { self.sql_expr_to_logical_expr(v[0].clone(), schema, planner_context) } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index c18587d8340c..94faa08e51b0 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -17,9 +17,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Column, DFField, DFSchema, DataFusionError, Result, ScalarValue, TableReference, + Column, DFField, DFSchema, DataFusionError, Result, TableReference, }; -use datafusion_expr::{Case, Expr, GetIndexedField}; +use datafusion_expr::{Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -136,10 +136,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))); } let nested_name = nested_names[0].to_string(); - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(Expr::Column(field.qualified_column())), - ScalarValue::Utf8(Some(nested_name)), - ))) + Ok(Expr::Column(field.qualified_column()).field(nested_name)) } // found matching field with no spare identifier(s) Some((field, _nested_names)) => { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 
0630bcb7e84d..aad9f770ff54 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,15 +29,16 @@ mod value; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; +use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr::{InList, Placeholder}; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, - Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast, + Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, TrimWhereField, Value}; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; use sqlparser::parser::ParserError::ParserError; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -182,7 +183,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::MapAccess { column, keys } => { if let SQLExpr::Identifier(id) = *column { - plan_indexed(col(self.normalizer.normalize(id)), keys) + self.plan_indexed(col(self.normalizer.normalize(id)), keys, schema, planner_context) } else { Err(DataFusionError::NotImplemented(format!( "map access requires an identifier, found column {column} instead" @@ -192,7 +193,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::ArrayIndex { obj, indexes } => { let expr = self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; - plan_indexed(expr, indexes) + self.plan_indexed(expr, indexes, schema, planner_context) } SQLExpr::CompoundIdentifier(ids) => self.sql_compound_identifier_to_expr(ids, schema, planner_context), @@ -414,9 +415,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { - return Err(DataFusionError::Plan( - "Invalid pattern in LIKE expression".to_string(), - )); + return plan_err!("Invalid pattern in LIKE expression"); } Ok(Expr::Like(Like::new( negated, @@ -439,9 +438,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let pattern = self.sql_expr_to_logical_expr(pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { - return Err(DataFusionError::Plan( - "Invalid pattern in SIMILAR TO expression".to_string(), - )); + return plan_err!("Invalid pattern in SIMILAR TO expression"); } Ok(Expr::SimilarTo(Like::new( negated, @@ -503,12 +500,76 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )?)), order_by, ))), - _ => Err(DataFusionError::Plan( + _ => plan_err!( "AggregateExpressionWithFilter expression was not an AggregateFunction" - .to_string(), - )), + ), } } + + fn plan_indices( + &self, + expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let field = match expr.clone() { + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => GetFieldAccess::NamedStructField { + name: ScalarValue::Utf8(Some(s)), + }, + SQLExpr::JsonAccess { + left, + operator: JsonOperator::Colon, + right, + } => { + let start = Box::new(self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?); + let stop = Box::new(self.sql_expr_to_logical_expr( + *right, + schema, + 
planner_context, + )?); + + GetFieldAccess::ListRange { start, stop } + } + _ => GetFieldAccess::ListIndex { + key: Box::new(self.sql_expr_to_logical_expr( + expr, + schema, + planner_context, + )?), + }, + }; + + Ok(field) + } + + fn plan_indexed( + &self, + expr: Expr, + mut keys: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let indices = keys.pop().ok_or_else(|| { + ParserError("Internal error: Missing index key expression".to_string()) + })?; + + let expr = if !keys.is_empty() { + self.plan_indexed(expr, keys, schema, planner_context)? + } else { + expr + }; + + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(expr), + self.plan_indices(indices, schema, planner_context)?, + ))) + } } // modifies expr if it is a placeholder with datatype of right @@ -544,42 +605,6 @@ fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { }) } -fn plan_key(key: SQLExpr) -> Result { - let scalar = match key { - SQLExpr::Value(Value::Number(s, _)) => ScalarValue::Int64(Some( - s.parse() - .map_err(|_| ParserError(format!("Cannot parse {s} as i64.")))?, - )), - SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => { - ScalarValue::Utf8(Some(s)) - } - _ => { - return Err(DataFusionError::SQL(ParserError(format!( - "Unsupported index key expression: {key:?}" - )))); - } - }; - - Ok(scalar) -} - -fn plan_indexed(expr: Expr, mut keys: Vec) -> Result { - let key = keys.pop().ok_or_else(|| { - ParserError("Internal error: Missing index key expression".to_string()) - })?; - - let expr = if !keys.is_empty() { - plan_indexed(expr, keys)? - } else { - expr - }; - - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - plan_key(key)?, - ))) -} - #[cfg(test)] mod tests { use super::*; @@ -628,10 +653,7 @@ mod tests { ) -> Result> { match self.tables.get(name.table()) { Some(table) => Ok(table.clone()), - _ => Err(DataFusionError::Plan(format!( - "Table not found: {}", - name.table() - ))), + _ => plan_err!("Table not found: {}", name.table()), } } diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index 1a9526661542..d95a25b8fad9 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,6 +16,7 @@ // under the License. 
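// Illustrative sketch: what the `plan_indices` / `plan_indexed` helpers added above
// (datafusion/sql/src/expr/mod.rs) produce for the three subscript forms. A sketch
// only -- column names and literal values are placeholders, and the SQL shown in
// the comments is the syntax those helpers receive from sqlparser.
use datafusion_common::ScalarValue;
use datafusion_expr::{col, lit, Expr, GetFieldAccess, GetIndexedField};

fn subscript_examples() -> Vec<Expr> {
    vec![
        // s['name']  =>  named struct field access
        Expr::GetIndexedField(GetIndexedField::new(
            Box::new(col("s")),
            GetFieldAccess::NamedStructField {
                name: ScalarValue::Utf8(Some("name".to_string())),
            },
        )),
        // a[1]       =>  list index access
        Expr::GetIndexedField(GetIndexedField::new(
            Box::new(col("a")),
            GetFieldAccess::ListIndex {
                key: Box::new(lit(1i64)),
            },
        )),
        // a[1:3]     =>  list range access (sqlparser's JsonOperator::Colon)
        Expr::GetIndexedField(GetIndexedField::new(
            Box::new(col("a")),
            GetFieldAccess::ListRange {
                start: Box::new(lit(1i64)),
                stop: Box::new(lit(3i64)),
            },
        )),
    ]
}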
use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_common::plan_err; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{BuiltinScalarFunction, Expr}; @@ -62,9 +63,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { substring_for: None, }; - return Err(DataFusionError::Plan(format!( - "Substring without for/from is not valid {orig_sql:?}" - ))); + return plan_err!("Substring without for/from is not valid {orig_sql:?}"); } }; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index c34b6571812d..642ac5d8d0e9 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -18,7 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; use arrow_schema::DataType; -use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::{plan_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Placeholder}; use datafusion_expr::{lit, Expr, Operator}; use log::debug; @@ -44,14 +44,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Some(v) = try_decode_hex_literal(&s) { Ok(lit(v)) } else { - Err(DataFusionError::Plan(format!( - "Invalid HexStringLiteral '{s}'" - ))) + plan_err!("Invalid HexStringLiteral '{s}'") } } - _ => Err(DataFusionError::Plan(format!( - "Unsupported Value '{value:?}'", - ))), + _ => plan_err!("Unsupported Value '{value:?}'"), } } @@ -104,15 +100,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let index = param[1..].parse::(); let idx = match index { Ok(0) => { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Invalid placeholder, zero is not a valid index: {param}" - ))); + ); } Ok(index) => index - 1, Err(_) => { - return Err(DataFusionError::Plan(format!( - "Invalid placeholder, not a number: {param}" - ))); + return plan_err!("Invalid placeholder, not a number: {param}"); } }; // Check if the placeholder is in the parameter list diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index efa3077c3b59..0f54638d746d 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -29,6 +29,7 @@ use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; use datafusion_common::config::ConfigOptions; +use datafusion_common::plan_err; use datafusion_common::{unqualified_field_not_found, DFSchema, DataFusionError, Result}; use datafusion_common::{OwnedTableReference, TableReference}; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; @@ -243,11 +244,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if idents.is_empty() { Ok(plan) } else if idents.len() != plan.schema().fields().len() { - Err(DataFusionError::Plan(format!( + plan_err!( "Source table contains {} columns but only {} names given as column alias", plan.schema().fields().len(), - idents.len(), - ))) + idents.len() + ) } else { let fields = plan.schema().fields().clone(); LogicalPlanBuilder::from(plan) @@ -461,10 +462,7 @@ pub(crate) fn idents_to_table_reference( let catalog = taker.take(enable_normalization); Ok(OwnedTableReference::full(catalog, schema, table)) } - _ => Err(DataFusionError::Plan(format!( - "Unsupported compound identifier '{:?}'", - taker.0, - ))), + _ => plan_err!("Unsupported compound identifier '{:?}'", 
taker.0), } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 34b24b0594fa..272998e4a8de 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -27,6 +27,7 @@ use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, }; +use datafusion_common::plan_err; use sqlparser::parser::ParserError::ParserError; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -117,15 +118,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )? { Expr::Literal(ScalarValue::Int64(Some(s))) => { if s < 0 { - return Err(DataFusionError::Plan(format!( - "Offset must be >= 0, '{s}' was provided." - ))); + return plan_err!("Offset must be >= 0, '{s}' was provided."); } Ok(s as usize) } - _ => Err(DataFusionError::Plan( - "Unexpected expression in OFFSET clause".to_string(), - )), + _ => plan_err!("Unexpected expression in OFFSET clause"), }?, _ => 0, }; @@ -142,9 +139,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Literal(ScalarValue::Int64(Some(n))) if n >= 0 => { Ok(n as usize) } - _ => Err(DataFusionError::Plan( - "LIMIT must not be negative".to_string(), - )), + _ => plan_err!("LIMIT must not be negative"), }?; Some(n) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 59656587fa79..a4e6ae007faa 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -24,6 +24,7 @@ use crate::utils::{ resolve_columns, resolve_positions_to_exprs, }; +use datafusion_common::plan_err; use datafusion_common::{ get_target_functional_dependencies, DFSchemaRef, DataFusionError, Result, }; @@ -170,8 +171,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )? } else { match having_expr_opt { - Some(having_expr) => return Err(DataFusionError::Plan( - format!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"))), + Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"), None => (plan, select_exprs, having_expr_opt) } }; @@ -358,9 +358,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Self::check_wildcard_options(&options)?; if empty_from { - return Err(DataFusionError::Plan( - "SELECT * with no tables specified is not valid".to_string(), - )); + return plan_err!("SELECT * with no tables specified is not valid"); } // do not expand from outer schema expand_wildcard(plan.schema().as_ref(), plan, Some(options)) @@ -521,10 +519,10 @@ fn check_conflicting_windows(window_defs: &[NamedWindowDefinition]) -> Result<() for (i, window_def_i) in window_defs.iter().enumerate() { for window_def_j in window_defs.iter().skip(i + 1) { if window_def_i.0 == window_def_j.0 { - return Err(DataFusionError::Plan(format!( + return plan_err!( "The window {} is defined multiple times!", window_def_i.0 - ))); + ); } } } @@ -553,9 +551,7 @@ fn match_window_definitions( } // All named windows must be defined with a WindowSpec. if let Some(WindowType::NamedWindow(ident)) = &f.over { - return Err(DataFusionError::Plan(format!( - "The window {ident} is not defined!" 
- ))); + return plan_err!("The window {ident} is not defined!"); } } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 4af32337f77a..ad66640efa14 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -53,6 +53,7 @@ use sqlparser::ast::{ }; use sqlparser::parser::ParserError::ParserError; +use datafusion_common::plan_err; use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; @@ -140,11 +141,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = if !columns.is_empty() { let schema = self.build_schema(columns)?.to_dfschema_ref()?; if schema.fields().len() != input_schema.fields().len() { - return Err(DataFusionError::Plan(format!( + return plan_err!( "Mismatch: {} columns specified, but result has {} columns", schema.fields().len(), input_schema.fields().len() - ))); + ); } let input_fields = input_schema.fields(); let project_exprs = schema @@ -356,42 +357,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, } => { if or.is_some() { - Err(DataFusionError::Plan( - "Inserts with or clauses not supported".to_owned(), - ))?; - } - if overwrite { - Err(DataFusionError::Plan( - "Insert overwrite is not supported".to_owned(), - ))?; + plan_err!("Inserts with or clauses not supported")?; } if partitioned.is_some() { - Err(DataFusionError::Plan( - "Partitioned inserts not yet supported".to_owned(), - ))?; + plan_err!("Partitioned inserts not yet supported")?; } if !after_columns.is_empty() { - Err(DataFusionError::Plan( - "After-columns clause not supported".to_owned(), - ))?; + plan_err!("After-columns clause not supported")?; } if table { - Err(DataFusionError::Plan( - "Table clause not supported".to_owned(), - ))?; + plan_err!("Table clause not supported")?; } if on.is_some() { - Err(DataFusionError::Plan( - "Insert-on clause not supported".to_owned(), - ))?; + plan_err!("Insert-on clause not supported")?; } if returning.is_some() { - Err(DataFusionError::Plan( - "Insert-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Insert-returning clause not supported")?; } let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source) + self.insert_to_plan(table_name, columns, source, overwrite) } Statement::Update { @@ -402,9 +386,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, } => { if returning.is_some() { - Err(DataFusionError::Plan( - "Update-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Update-returning clause not yet supported")?; } self.update_to_plan(table, assignments, from, selection) } @@ -417,20 +399,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from, } => { if !tables.is_empty() { - return Err(DataFusionError::NotImplemented( - "DELETE not supported".to_string(), - )); + plan_err!("DELETE
not supported")?; } if using.is_some() { - Err(DataFusionError::Plan( - "Using clause not supported".to_owned(), - ))?; + plan_err!("Using clause not supported")?; } + if returning.is_some() { - Err(DataFusionError::Plan( - "Delete-returning clause not yet supported".to_owned(), - ))?; + plan_err!("Delete-returning clause not yet supported")?; } let table_name = self.get_delete_target(from)?; self.delete_to_plan(table_name, selection) @@ -541,9 +518,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // we only support the basic "SHOW TABLES" // https://github.com/apache/arrow-datafusion/issues/3188 if db_name.is_some() || filter.is_some() || full || extended { - Err(DataFusionError::Plan( - "Unsupported parameters to SHOW TABLES".to_string(), - )) + plan_err!("Unsupported parameters to SHOW TABLES") } else { let query = "SELECT * FROM information_schema.tables;"; let mut rewrite = DFParser::parse_sql(query)?; @@ -551,10 +526,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1 } } else { - Err(DataFusionError::Plan( - "SHOW TABLES is not supported unless information_schema is enabled" - .to_string(), - )) + plan_err!("SHOW TABLES is not supported unless information_schema is enabled") } } @@ -590,10 +562,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result>> { // Ask user to provide a schema if schema is empty. if !order_exprs.is_empty() && schema.fields().is_empty() { - return Err(DataFusionError::Plan( + return plan_err!( "Provide a schema before specifying the order while creating a table." - .to_owned(), - )); + ); } let mut all_results = vec![]; @@ -605,9 +576,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { for column in expr.to_columns()?.iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: - return Err(DataFusionError::Plan(format!( - "Column {column} is not in schema" - ))); + return plan_err!("Column {column} is not in schema"); } } } @@ -640,18 +609,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // semantic checks if file_type == "PARQUET" && !columns.is_empty() { - Err(DataFusionError::Plan( - "Column definitions can not be specified for PARQUET files.".into(), - ))?; + plan_err!("Column definitions can not be specified for PARQUET files.")?; } if file_type != "CSV" && file_type != "JSON" && file_compression_type != CompressionTypeVariant::UNCOMPRESSED { - Err(DataFusionError::Plan( - "File compression type can be specified for CSV/JSON files.".into(), - ))?; + plan_err!("File compression type can be specified for CSV/JSON files.")?; } let schema = self.build_schema(columns)?; @@ -718,10 +683,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let variable = object_name_to_string(&ObjectName(variable.to_vec())); if !self.has_table("information_schema", "df_settings") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW [VARIABLE] is not supported unless information_schema is enabled" - .to_string(), - )); + ); } let variable_lower = variable.to_lowercase(); @@ -790,10 +754,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | Value::HexStringLiteral(_) | Value::Null | Value::Placeholder(_) => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }, // for capture signed number e.g. 
+8, -8 @@ -801,17 +762,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { UnaryOperator::Plus => format!("+{expr}"), UnaryOperator::Minus => format!("-{expr}"), _ => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }, _ => { - return Err(DataFusionError::Plan(format!( - "Unsupported Value {}", - value[0] - ))); + return plan_err!("Unsupported Value {}", value[0]); } }; @@ -874,9 +829,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let table_name = match &table.relation { TableFactor::Table { name, .. } => name.clone(), - _ => Err(DataFusionError::Plan( - "Cannot update non-table relation!".to_string(), - ))?, + _ => plan_err!("Cannot update non-table relation!")?, }; // Do a table lookup to verify the table exists @@ -978,6 +931,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_name: ObjectName, columns: Vec, source: Box, + overwrite: bool, ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; @@ -1053,9 +1007,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { PlannerContext::new().with_prepare_param_data_types(prepare_param_data_types); let source = self.query_to_plan(*source, &mut planner_context)?; if fields.len() != source.schema().fields().len() { - Err(DataFusionError::Plan( - "Column count doesn't match insert query!".to_owned(), - ))?; + plan_err!("Column count doesn't match insert query!")?; } let exprs = index_mapping @@ -1073,10 +1025,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; let source = project(source, exprs)?; + let op = if overwrite { + WriteOp::InsertOverwrite + } else { + WriteOp::InsertInto + }; + let plan = LogicalPlan::Dml(DmlStatement { table_name, table_schema: Arc::new(table_schema), - op: WriteOp::Insert, + op, input: Arc::new(source), }); Ok(plan) @@ -1090,16 +1048,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter: Option, ) -> Result { if filter.is_some() { - return Err(DataFusionError::Plan( - "SHOW COLUMNS with WHERE or LIKE is not supported".to_string(), - )); + return plan_err!("SHOW COLUMNS with WHERE or LIKE is not supported"); } if !self.has_table("information_schema", "columns") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW COLUMNS is not supported unless information_schema is enabled" - .to_string(), - )); + ); } // Figure out the where clause let where_clause = object_name_to_qualifier( @@ -1132,10 +1087,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { sql_table_name: ObjectName, ) -> Result { if !self.has_table("information_schema", "tables") { - return Err(DataFusionError::Plan( + return plan_err!( "SHOW CREATE TABLE is not supported unless information_schema is enabled" - .to_string(), - )); + ); } // Figure out the where clause let where_clause = object_name_to_qualifier( diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 27e62526b859..5a570f390377 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -20,6 +20,7 @@ use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE}; use sqlparser::ast::Ident; +use datafusion_common::plan_err; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{ AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, GetIndexedField, @@ -121,12 +122,12 @@ fn check_column_satisfies_expr( message_prefix: &str, ) -> Result<()> { if !columns.contains(expr) { 
- return Err(DataFusionError::Plan(format!( + return plan_err!( "{}: Expression {} could not be resolved from available columns: {}", message_prefix, expr, expr_vec_fmt!(columns) - ))); + ); } Ok(()) } @@ -376,10 +377,10 @@ where ))), Expr::Wildcard => Ok(Expr::Wildcard), Expr::QualifiedWildcard { .. } => Ok(expr.clone()), - Expr::GetIndexedField(GetIndexedField { key, expr }) => { + Expr::GetIndexedField(GetIndexedField { expr, field }) => { Ok(Expr::GetIndexedField(GetIndexedField::new( Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), - key.clone(), + field.clone(), ))) } Expr::GroupingSet(set) => match set { @@ -511,9 +512,7 @@ pub(crate) fn make_decimal_type( (Some(p), Some(s)) => (p as u8, s as i8), (Some(p), None) => (p as u8, 0), (None, Some(_)) => { - return Err(DataFusionError::Plan( - "Cannot specify only scale for decimal data type".to_string(), - )) + return plan_err!("Cannot specify only scale for decimal data type") } (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), }; @@ -523,9 +522,9 @@ pub(crate) fn make_decimal_type( || precision > DECIMAL128_MAX_PRECISION || scale.unsigned_abs() > precision { - Err(DataFusionError::Plan(format!( + plan_err!( "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 38`, and `scale <= precision`." - ))) + ) } else { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 26663f88c3a6..eef9093947fb 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,6 +22,7 @@ use std::{sync::Arc, vec}; use arrow_schema::*; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; +use datafusion_common::plan_err; use datafusion_common::{ assert_contains, config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, @@ -329,7 +330,7 @@ fn plan_insert() { let sql = "insert into person (id, first_name, last_name) values (1, 'Alan', 'Turing')"; let plan = r#" -Dml: op=[Insert] table=[person] +Dml: op=[Insert Into] table=[person] Projection: CAST(column1 AS UInt32) AS id, column2 AS first_name, column3 AS last_name Values: (Int64(1), Utf8("Alan"), Utf8("Turing")) "# @@ -341,7 +342,7 @@ Dml: op=[Insert] table=[person] fn plan_insert_no_target_columns() { let sql = "INSERT INTO test_decimal VALUES (1, 2), (3, 4)"; let plan = r#" -Dml: op=[Insert] table=[test_decimal] +Dml: op=[Insert Into] table=[test_decimal] Projection: CAST(column1 AS Int32) AS id, CAST(column2 AS Decimal128(10, 2)) AS price Values: (Int64(1), Int64(2)), (Int64(3), Int64(4)) "# @@ -2695,10 +2696,7 @@ impl ContextProvider for MockContextProvider { Field::new("Id", DataType::UInt32, false), Field::new("lower", DataType::UInt32, false), ])), - _ => Err(DataFusionError::Plan(format!( - "No table named: {} found", - name.table() - ))), + _ => plan_err!("No table named: {} found", name.table()), }; match schema { @@ -3882,7 +3880,7 @@ fn test_prepare_statement_insert_infer() { let sql = "insert into person (id, first_name, last_name) values ($1, $2, $3)"; let expected_plan = r#" -Dml: op=[Insert] table=[person] +Dml: op=[Insert Into] table=[person] Projection: column1 AS id, column2 AS first_name, column3 AS last_name Values: ($1, $2, $3) "# @@ -3906,7 +3904,7 @@ Dml: op=[Insert] table=[person] ScalarValue::Utf8(Some("Turing".to_string())), ]; let expected_plan = r#" -Dml: op=[Insert] table=[person] +Dml: op=[Insert Into] table=[person] Projection: column1 AS id, 
column2 AS first_name, column3 AS last_name Values: (UInt32(1), Utf8("Alan"), Utf8("Turing")) "# diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml new file mode 100644 index 000000000000..66961abdfaca --- /dev/null +++ b/datafusion/sqllogictest/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +name = "datafusion-sqllogictest" +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[lib] +name = "datafusion_sqllogictest" +path = "src/lib.rs" + +[dependencies] +arrow = {workspace = true} +async-trait = "0.1.41" +bigdecimal = "0.4.1" +datafusion = {path = "../core", version = "28.0.0"} +datafusion-common = {path = "../common", version = "28.0.0"} +half = "2.2.1" +itertools = "0.11" +lazy_static = {version = "^1.4.0"} +object_store = "0.6.1" +rust_decimal = {version = "1.27.0"} +log = "^0.4" +sqllogictest = "0.15.0" +sqlparser.workspace = true +thiserror = "1.0.44" +tokio = {version = "1.0"} +bytes = {version = "1.4.0", optional = true} +futures = {version = "0.3.28", optional = true} +chrono = {version = "0.4.26", optional = true} +tokio-postgres = {version = "0.7.7", optional = true} +postgres-types = {version = "0.2.4", optional = true} +postgres-protocol = {version = "0.6.4", optional = true} + +[features] +postgres = ["bytes", "futures", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] diff --git a/datafusion/core/tests/sqllogictests/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs similarity index 82% rename from datafusion/core/tests/sqllogictests/src/engines/conversion.rs rename to datafusion/sqllogictest/src/engines/conversion.rs index c069c2d4a48d..a44783b098c9 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -19,11 +19,11 @@ use arrow::datatypes::{Decimal128Type, DecimalType}; use bigdecimal::BigDecimal; use half::f16; use rust_decimal::prelude::*; -use rust_decimal::Decimal; +/// Represents a constant for NULL string in your database. 
pub const NULL_STR: &str = "NULL"; -pub fn bool_to_str(value: bool) -> String { +pub(crate) fn bool_to_str(value: bool) -> String { if value { "true".to_string() } else { @@ -31,7 +31,7 @@ pub fn bool_to_str(value: bool) -> String { } } -pub fn varchar_to_str(value: &str) -> String { +pub(crate) fn varchar_to_str(value: &str) -> String { if value.is_empty() { "(empty)".to_string() } else { @@ -39,7 +39,7 @@ pub fn varchar_to_str(value: &str) -> String { } } -pub fn f16_to_str(value: f16) -> String { +pub(crate) fn f16_to_str(value: f16) -> String { if value.is_nan() { "NaN".to_string() } else if value == f16::INFINITY { @@ -51,7 +51,7 @@ pub fn f16_to_str(value: f16) -> String { } } -pub fn f32_to_str(value: f32) -> String { +pub(crate) fn f32_to_str(value: f32) -> String { if value.is_nan() { "NaN".to_string() } else if value == f32::INFINITY { @@ -63,7 +63,7 @@ pub fn f32_to_str(value: f32) -> String { } } -pub fn f64_to_str(value: f64) -> String { +pub(crate) fn f64_to_str(value: f64) -> String { if value.is_nan() { "NaN".to_string() } else if value == f64::INFINITY { @@ -75,17 +75,17 @@ pub fn f64_to_str(value: f64) -> String { } } -pub fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { +pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { big_decimal_to_str( BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale)) .unwrap(), ) } -pub fn decimal_to_str(value: Decimal) -> String { +pub(crate) fn decimal_to_str(value: Decimal) -> String { big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) } -pub fn big_decimal_to_str(value: BigDecimal) -> String { +pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { value.round(12).normalized().to_string() } diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/error.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/error.rs similarity index 100% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/error.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/error.rs diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs new file mode 100644 index 000000000000..663bbdd5a3c7 --- /dev/null +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// DataFusion engine implementation for sqllogictest. 
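+/// The `error` module defines the engine's error type, `normalize` converts
+/// Arrow `RecordBatch` results into the textual rows that sqllogictest compares
+/// against, and `runner` provides the `DataFusion` struct that wraps a
+/// `SessionContext` and implements `sqllogictest::AsyncDB`; all three are
+/// re-exported below.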
+mod error; +mod normalize; +mod runner; + +pub use error::*; +pub use normalize::*; +pub use runner::*; diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs similarity index 97% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index 6dd4e17d7dd7..954926ae3310 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -27,7 +27,7 @@ use super::super::conversion::*; use super::error::{DFSqlLogicTestError, Result}; /// Converts `batches` to a result as expected by sqllogicteset. -pub fn convert_batches(batches: Vec) -> Result>> { +pub(crate) fn convert_batches(batches: Vec) -> Result>> { if batches.is_empty() { Ok(vec![]) } else { @@ -113,13 +113,13 @@ fn expand_row(mut row: Vec) -> impl Iterator> { /// normalize path references /// -/// ``` +/// ```text /// CsvExec: files={1 group: [[path/to/datafusion/testing/data/csv/aggregate_test_100.csv]]}, ... /// ``` /// /// into: /// -/// ``` +/// ```text /// CsvExec: files={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, ... /// ``` fn normalize_paths(mut row: Vec) -> Vec { @@ -230,7 +230,7 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { } /// Converts columns to a result as expected by sqllogicteset. -pub fn convert_schema_to_types(columns: &[DFField]) -> Vec { +pub(crate) fn convert_schema_to_types(columns: &[DFField]) -> Vec { columns .iter() .map(|f| f.data_type()) diff --git a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs similarity index 91% rename from datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs rename to datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index dd30ef494d49..afd0a241ca5e 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -15,21 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::path::PathBuf; -use std::time::Duration; +use std::{path::PathBuf, time::Duration}; -use crate::engines::output::{DFColumnType, DFOutput}; - -use self::error::{DFSqlLogicTestError, Result}; +use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use datafusion::arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; use log::info; use sqllogictest::DBOutput; -mod error; -mod normalize; -mod util; +use super::{error::Result, normalize, DFSqlLogicTestError}; + +use crate::engines::output::{DFColumnType, DFOutput}; pub struct DataFusion { ctx: SessionContext, @@ -61,7 +57,7 @@ impl sqllogictest::AsyncDB for DataFusion { "DataFusion" } - /// [`Runner`] calls this function to perform sleep. + /// [`DataFusion`] calls this function to perform sleep. /// /// The default implementation is `std::thread::sleep`, which is universal to any async runtime /// but would block the current thread. 
If you are running in tokio runtime, you should override diff --git a/datafusion/sqllogictest/src/engines/mod.rs b/datafusion/sqllogictest/src/engines/mod.rs new file mode 100644 index 000000000000..a6a0886332ed --- /dev/null +++ b/datafusion/sqllogictest/src/engines/mod.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Implementation of sqllogictest for datafusion. +mod conversion; +mod datafusion_engine; +mod output; + +pub use datafusion_engine::DataFusion; + +#[cfg(feature = "postgres")] +mod postgres_engine; + +#[cfg(feature = "postgres")] +pub use postgres_engine::Postgres; diff --git a/datafusion/core/tests/sqllogictests/src/engines/output.rs b/datafusion/sqllogictest/src/engines/output.rs similarity index 97% rename from datafusion/core/tests/sqllogictests/src/engines/output.rs rename to datafusion/sqllogictest/src/engines/output.rs index 0682f5df97c1..24299856e00d 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/output.rs +++ b/datafusion/sqllogictest/src/engines/output.rs @@ -54,4 +54,4 @@ impl ColumnType for DFColumnType { } } -pub type DFOutput = DBOutput; +pub(crate) type DFOutput = DBOutput; diff --git a/datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs similarity index 99% rename from datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs rename to datafusion/sqllogictest/src/engines/postgres_engine/mod.rs index 2c6287b97bfd..fe2785603e76 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/postgres/mod.rs +++ b/datafusion/sqllogictest/src/engines/postgres_engine/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +/// Postgres engine implementation for sqllogictest. use std::path::{Path, PathBuf}; use std::str::FromStr; diff --git a/datafusion/core/tests/sqllogictests/src/engines/postgres/types.rs b/datafusion/sqllogictest/src/engines/postgres_engine/types.rs similarity index 100% rename from datafusion/core/tests/sqllogictests/src/engines/postgres/types.rs rename to datafusion/sqllogictest/src/engines/postgres_engine/types.rs diff --git a/datafusion/core/tests/sqllogictests/src/engines/mod.rs b/datafusion/sqllogictest/src/lib.rs similarity index 88% rename from datafusion/core/tests/sqllogictests/src/engines/mod.rs rename to datafusion/sqllogictest/src/lib.rs index a2657bb60017..b739d75777de 100644 --- a/datafusion/core/tests/sqllogictests/src/engines/mod.rs +++ b/datafusion/sqllogictest/src/lib.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. 
-mod conversion; -pub mod datafusion; -mod output; -pub mod postgres; +mod engines; + +pub use engines::DataFusion; + +#[cfg(feature = "postgres")] +pub use engines::Postgres; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 2c37f7ede100..4e4d71ddb604 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -53,6 +53,7 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; +use datafusion::common::plan_err; use datafusion::logical_expr::expr::{InList, Sort}; use std::collections::HashMap; use std::str::FromStr; @@ -165,7 +166,7 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) } }, - None => Err(DataFusionError::Plan("Cannot parse plan relation: None".to_string())) + None => plan_err!("Cannot parse plan relation: None") } }, _ => Err(DataFusionError::NotImplemented(format!( @@ -362,8 +363,7 @@ pub async fn from_substrait_rel( // TODO: collect only one null_eq_null let join_exprs: Vec<(Column, Column, bool)> = predicates .iter() - .map(|p| { - match p { + .map(|p| match p { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { match (left.as_ref(), right.as_ref()) { (Expr::Column(l), Expr::Column(r)) => match op { @@ -371,20 +371,14 @@ pub async fn from_substrait_rel( Operator::IsNotDistinctFrom => { Ok((l.clone(), r.clone(), true)) } - _ => Err(DataFusionError::Plan( - "invalid join condition op".to_string(), - )), + _ => plan_err!("invalid join condition op"), }, - _ => Err(DataFusionError::Plan( - "invalid join condition expression".to_string(), - )), + _ => plan_err!("invalid join condition expression"), } } - _ => Err(DataFusionError::Plan( + _ => plan_err!( "Non-binary expression is not supported in join condition" - .to_string(), - )), - } + ), }) .collect::>>()?; let (left_cols, right_cols, null_eq_nulls): (Vec<_>, Vec<_>, Vec<_>) = @@ -407,9 +401,7 @@ pub async fn from_substrait_rel( join_filter, )? 
.build(), - None => Err(DataFusionError::Plan( - "Join without join keys require a valid filter".to_string(), - )), + None => plan_err!("Join without join keys require a valid filter"), }, } } @@ -417,9 +409,7 @@ pub async fn from_substrait_rel( Some(ReadType::NamedTable(nt)) => { let table_reference = match nt.names.len() { 0 => { - return Err(DataFusionError::Plan( - "No table name found in NamedTable".to_string(), - )); + return plan_err!("No table name found in NamedTable"); } 1 => TableReference::Bare { table: (&nt.names[0]).into(), @@ -459,9 +449,7 @@ pub async fn from_substrait_rel( )?); Ok(LogicalPlan::TableScan(scan)) } - _ => Err(DataFusionError::Plan( - "unexpected plan for table".to_string(), - )), + _ => plan_err!("unexpected plan for table"), } } _ => Ok(t), @@ -564,14 +552,10 @@ fn from_substrait_jointype(join_type: i32) -> Result { join_rel::JoinType::Outer => Ok(JoinType::Full), join_rel::JoinType::Anti => Ok(JoinType::LeftAnti), join_rel::JoinType::Semi => Ok(JoinType::LeftSemi), - _ => Err(DataFusionError::Plan(format!( - "unsupported join type {substrait_join_type:?}" - ))), + _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { - Err(DataFusionError::Plan(format!( - "invalid join type variant {join_type:?}" - ))) + plan_err!("invalid join type variant {join_type:?}") } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 07da2dc0bee1..d3337e736d35 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -471,7 +471,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Gt => "gt", Operator::GtEq => "gte", Operator::Plus => "add", - Operator::Minus => "substract", + Operator::Minus => "subtract", Operator::Multiply => "multiply", Operator::Divide => "divide", Operator::Modulo => "mod", diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f297264c3dd5..90c3d199b740 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -410,6 +410,21 @@ async fn roundtrip_outer_join() -> Result<()> { roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await } +#[tokio::test] +async fn roundtrip_arithmetic_ops() -> Result<()> { + roundtrip("SELECT a - a FROM data").await?; + roundtrip("SELECT a + a FROM data").await?; + roundtrip("SELECT a * a FROM data").await?; + roundtrip("SELECT a / a FROM data").await?; + roundtrip("SELECT a = a FROM data").await?; + roundtrip("SELECT a != a FROM data").await?; + roundtrip("SELECT a > a FROM data").await?; + roundtrip("SELECT a >= a FROM data").await?; + roundtrip("SELECT a < a FROM data").await?; + roundtrip("SELECT a <= a FROM data").await?; + Ok(()) +} + #[tokio::test] async fn roundtrip_like() -> Result<()> { roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await diff --git a/dev/release/README.md b/dev/release/README.md index ac180632367c..1fd062e3d5c0 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -284,7 +284,10 @@ of the following crates: - [datafusion-expr](https://crates.io/crates/datafusion-expr) - [datafusion-physical-expr](https://crates.io/crates/datafusion-physical-expr) - [datafusion-proto](https://crates.io/crates/datafusion-proto) -- [datafusion-row](https://crates.io/crates/datafusion-row) +- 
[datafusion-execution](https://crates.io/crates/datafusion-execution) +- [datafusion-sql](https://crates.io/crates/datafusion-sql) +- [datafusion-optimizer](https://crates.io/crates/datafusion-optimizer) +- [datafusion-substrait](https://crates.io/crates/datafusion-substrait) Download and unpack the official release tarball @@ -308,7 +311,6 @@ dot -Tsvg dev/release/crate-deps.dot > dev/release/crate-deps.svg (cd datafusion/common && cargo publish) (cd datafusion/expr && cargo publish) (cd datafusion/sql && cargo publish) -(cd datafusion/row && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/execution && cargo publish) @@ -385,15 +387,16 @@ You can include mention crates.io and PyPI version URLs in the email if applicab ``` We have published new versions of DataFusion to crates.io: -https://crates.io/crates/datafusion/8.0.0 -https://crates.io/crates/datafusion-cli/8.0.0 -https://crates.io/crates/datafusion-common/8.0.0 -https://crates.io/crates/datafusion-expr/8.0.0 -https://crates.io/crates/datafusion-optimizer/8.0.0 -https://crates.io/crates/datafusion-physical-expr/8.0.0 -https://crates.io/crates/datafusion-proto/8.0.0 -https://crates.io/crates/datafusion-row/8.0.0 -https://crates.io/crates/datafusion-sql/8.0.0 +https://crates.io/crates/datafusion/28.0.0 +https://crates.io/crates/datafusion-cli/28.0.0 +https://crates.io/crates/datafusion-common/28.0.0 +https://crates.io/crates/datafusion-expr/28.0.0 +https://crates.io/crates/datafusion-optimizer/28.0.0 +https://crates.io/crates/datafusion-physical-expr/28.0.0 +https://crates.io/crates/datafusion-proto/28.0.0 +https://crates.io/crates/datafusion-sql/28.0.0 +https://crates.io/crates/datafusion-execution/28.0.0 +https://crates.io/crates/datafusion-substrait/28.0.0 ``` ### Add the release to Apache Reporter diff --git a/dev/release/crate-deps.dot b/dev/release/crate-deps.dot index 756614d4d344..6d1aa6e5807d 100644 --- a/dev/release/crate-deps.dot +++ b/dev/release/crate-deps.dot @@ -30,8 +30,11 @@ digraph G { datafusion_physical_expr -> datafusion_common datafusion_physical_expr -> datafusion_expr + datafusion_execution -> datafusion_common + datafusion_execution -> datafusion_expr datafusion -> datafusion_common + datafusion -> datafusion_execution datafusion -> datafusion_expr datafusion -> datafusion_optimizer datafusion -> datafusion_physical_expr diff --git a/dev/release/crate-deps.svg b/dev/release/crate-deps.svg index 388e9e7705dc..63f61bde1871 100644 --- a/dev/release/crate-deps.svg +++ b/dev/release/crate-deps.svg @@ -1,157 +1,181 @@ - - + G - + datafusion_common - -datafusion_common + +datafusion_common datafusion_expr - -datafusion_expr + +datafusion_expr datafusion_expr->datafusion_common - - + + datafusion_sql - -datafusion_sql + +datafusion_sql datafusion_sql->datafusion_common - - + + datafusion_sql->datafusion_expr - - + + datafusion_optimizer - -datafusion_optimizer + +datafusion_optimizer datafusion_optimizer->datafusion_common - - + + datafusion_optimizer->datafusion_expr - - + + datafusion_physical_expr - -datafusion_physical_expr + +datafusion_physical_expr datafusion_physical_expr->datafusion_common - - + + datafusion_physical_expr->datafusion_expr - - + + - + +datafusion_execution + +datafusion_execution + + + +datafusion_execution->datafusion_common + + + + + +datafusion_execution->datafusion_expr + + + + + datafusion - -datafusion + +datafusion - + datafusion->datafusion_common - - + + - + datafusion->datafusion_expr - - + + - + 
datafusion->datafusion_sql - - + + - + datafusion->datafusion_optimizer - - + + - + datafusion->datafusion_physical_expr - - + + + + + +datafusion->datafusion_execution + + - + datafusion_proto - -datafusion_proto + +datafusion_proto - + datafusion_proto->datafusion - - + + - + datafusion_substrait - -datafusion_substrait + +datafusion_substrait - + datafusion_substrait->datafusion - - + + - + datafusion_cli - -datafusion_cli + +datafusion_cli - + datafusion_cli->datafusion - - + + diff --git a/dev/release/release-crates.sh b/dev/release/release-crates.sh index 658ec88b899d..00ce77a86749 100644 --- a/dev/release/release-crates.sh +++ b/dev/release/release-crates.sh @@ -32,11 +32,12 @@ if ! [ git rev-parse --is-inside-work-tree ]; then cd datafusion/common && cargo publish cd datafusion/expr && cargo publish cd datafusion/sql && cargo publish - cd datafusion/row && cargo publish cd datafusion/physical-expr && cargo publish cd datafusion/optimizer && cargo publish cd datafusion/core && cargo publish cd datafusion/proto && cargo publish + cd datafusion/execution && cargo publish + cd datafusion/substrait && cargo publish cd datafusion-cli && cargo publish --no-verify else echo "Crates must be released from the source tarball that was voted on, not from the repo" diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 139e968eccfb..a04f43fd4b2b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -187,22 +187,24 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_all(array, sub-array) | Returns true if all elements of sub-array exist in array `array_has_all([1,2,3], [1,3]) -> true` | | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | -| array_fill(element, array) | Returns an array filled with copies of the given value. | +| array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | | array_prepend(array, element) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | +| array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | | array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | | array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | | array_remove_all(array, element) | Removes all elements from the array equal to the given value. 
`array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | | array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. `array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | | array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | +| array_slice(array, begin, end) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimeter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | -| trim_array(array, n) | Removes the last n elements from the array. | +| trim_array(array, n) | Deprecated | ## Regular Expressions diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 79081c8d3ab4..4304c67f4f1f 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -102,6 +102,7 @@ Here are some active projects using DataFusion: - [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake - [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database +- [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. - [InfluxDB IOx](https://github.com/influxdata/influxdb_iox) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 132ba47e2461..427a7bf130a7 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -245,6 +245,15 @@ last_value(expression [ORDER BY expression]) - [var](#var) - [var_pop](#var_pop) - [var_samp](#var_samp) +- [regr_avgx](#regr_avgx) +- [regr_avgy](#regr_avgy) +- [regr_count](#regr_count) +- [regr_intercept](#regr_intercept) +- [regr_r2](#regr_r2) +- [regr_slope](#regr_slope) +- [regr_sxx](#regr_sxx) +- [regr_syy](#regr_syy) +- [regr_sxy](#regr_sxy) ### `corr` @@ -384,6 +393,142 @@ var_samp(expression) - **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `regr_slope` + +Returns the slope of the linear regression line fitted to the non-null pairs of the input columns. +Given input columns Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using least-squares (minimal RSS) fitting. + +``` +regr_slope(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators.
+- **expression_x**: Expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_avgx` + +Computes the average of the independent variable (input) `expression_x` for the non-null paired data points. + +``` +regr_avgx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_avgy` + +Computes the average of the dependent variable (output) `expression_y` for the non-null paired data points. + +``` +regr_avgy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_count` + +Counts the number of non-null paired data points. + +``` +regr_count(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_intercept` + +Computes the y-intercept of the linear regression line. For the equation \(y = kx + b\), this function returns `b`. + +``` +regr_intercept(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_r2` + +Computes the square of the correlation coefficient between the independent and dependent variables. + +``` +regr_r2(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_sxx` + +Computes the sum of squares of the independent variable. + +``` +regr_sxx(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_syy` + +Computes the sum of squares of the dependent variable. + +``` +regr_syy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `regr_sxy` + +Computes the sum of products of paired data points. + +``` +regr_sxy(expression_y, expression_x) +``` + +#### Arguments + +- **expression_y**: Dependent variable. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Independent variable. 
+ Can be a constant, column, or function, and any combination of arithmetic operators. + ## Approximate - [approx_distinct](#approx_distinct) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 1e90edc1124d..dec120db18c5 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -43,6 +43,7 @@ - [log](#log) - [log10](#log10) - [log2](#log2) +- [nanvl](#nanvl) - [pi](#pi) - [power](#power) - [pow](#pow) @@ -353,6 +354,22 @@ log2(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `nanvl` + +Returns the first argument if it's not _NaN_. +Returns the second argument otherwise. + +``` +nanvl(expression_x, expression_y) +``` + +#### Arguments + +- **expression_x**: Numeric expression to return if it's not _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. + Can be a constant, column, or function, and any combination of arithmetic operators. + ### `pi` Returns an approximate value of π. @@ -396,7 +413,6 @@ radians(numeric_expression) - **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. - ======= ### `random` @@ -1430,7 +1446,8 @@ from_unixtime(expression) - [array_concat](#array_concat) - [array_contains](#array_contains) - [array_dims](#array_dims) -- [array_fill](#array_fill) +- [array_element](#array_element) +- [array_extract](#array_extract) - [array_indexof](#array_indexof) - [array_join](#array_join) - [array_length](#array_length) @@ -1440,18 +1457,22 @@ from_unixtime(expression) - [array_positions](#array_positions) - [array_push_back](#array_push_back) - [array_push_front](#array_push_front) +- [array_repeat](#array_repeat) - [array_remove](#array_remove) - [array_remove_n](#array_remove_n) - [array_remove_all](#array_remove_all) - [array_replace](#array_replace) - [array_replace_n](#array_replace_n) - [array_replace_all](#array_replace_all) +- [array_slice](#array_slice) - [array_to_string](#array_to_string) - [cardinality](#cardinality) - [list_append](#list_append) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) +- [list_element](#list_element) +- [list_extract](#list_extract) - [list_indexof](#list_indexof) - [list_join](#list_join) - [list_length](#list_length) @@ -1461,12 +1482,14 @@ from_unixtime(expression) - [list_positions](#list_positions) - [list_push_back](#list_push_back) - [list_push_front](#list_push_front) +- [list_repeat](#list_repeat) - [list_remove](#list_remove) - [list_remove_n](#list_remove_n) - [list_remove_all](#list_remove_all) - [list_replace](#list_replace) - [list_replace_n](#list_replace_n) - [list_replace_all](#list_replace_all) +- [list_slice](#list_slice) - [list_to_string](#list_to_string) - [make_array](#make_array) - [make_list](#make_list) @@ -1611,10 +1634,47 @@ array_dims(array) - list_dims +### `array_element` + +Extracts the element with the index n from the array. + +``` +array_element(array, index) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. 
+ +#### Example + +``` +❯ select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +``` + +#### Aliases + +- array_extract +- list_element +- list_extract + +### `array_extract` + +_Alias of [array_element](#array_element)._ + ### `array_fill` Returns an array filled with copies of the given value. +DEPRECATED: use `array_repeat` instead! + ``` array_fill(element, array) ``` @@ -1791,6 +1851,40 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ +### `array_repeat` + +Returns an array containing element `count` times. + +``` +array_repeat(element, count) +``` + +#### Arguments + +- **element**: Element expression. + Can be a constant, column, or function, and any combination of array operators. +- **count**: Value of how many times to repeat the element. + +#### Example + +``` +❯ select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +``` + +``` +❯ select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +``` + ### `array_remove` Removes the first element from the array equal to the given value. @@ -1816,6 +1910,10 @@ array_remove(array, element) +----------------------------------------------+ ``` +#### Aliases + +- list_remove + ### `array_remove_n` Removes the first `max` elements from the array equal to the given value. @@ -1842,6 +1940,10 @@ array_remove_n(array, element, max) +---------------------------------------------------------+ ``` +#### Aliases + +- list_remove_n + ### `array_remove_all` Removes all elements from the array equal to the given value. @@ -1867,6 +1969,10 @@ array_remove_all(array, element) +--------------------------------------------------+ ``` +#### Aliases + +- list_remove_all + ### `array_replace` Replaces the first occurrence of the specified element with another specified element. @@ -1893,6 +1999,10 @@ array_replace(array, from, to) +--------------------------------------------------------+ ``` +#### Aliases + +- list_replace + ### `array_replace_n` Replaces the first `max` occurrences of the specified element with another specified element. @@ -1920,6 +2030,10 @@ array_replace_n(array, from, to, max) +-------------------------------------------------------------------+ ``` +#### Aliases + +- list_replace_n + ### `array_replace_all` Replaces all occurrences of the specified element with another specified element. @@ -1946,6 +2060,33 @@ array_replace_all(array, from, to) +------------------------------------------------------------+ ``` +#### Aliases + +- list_replace_all + +### `array_slice` + +Returns a slice of the array. + +``` +array_slice(array, begin, end) +``` + +#### Example + +``` +❯ select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +``` + +#### Aliases + +- list_slice + ### `array_to_string` Converts each element to its text representation. 
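A usage sketch in the style of the surrounding examples; the query and result value mirror the `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` entry in the expressions guide, while the exact result-column header shown here is assumed:

```
❯ select array_to_string([1, 2, 3, 4], ',');
+--------------------------------------------+
| array_to_string(List([1,2,3,4]),Utf8(",")) |
+--------------------------------------------+
| 1,2,3,4                                    |
+--------------------------------------------+
```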
@@ -2017,6 +2158,14 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_dims](#array_dims)._ +### `list_element` + +_Alias of [array_element](#array_element)._ + +### `list_extract` + +_Alias of [array_element](#array_element)._ + ### `list_indexof` _Alias of [array_position](#array_position)._ @@ -2053,6 +2202,10 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ +### `list_repeat` + +_Alias of [array_repeat](#array_repeat)._ + ### `list_remove` _Alias of [array_remove](#array_remove)._ @@ -2077,6 +2230,10 @@ _Alias of [array_replace_n](#array_replace_n)._ _Alias of [array_replace_all](#array_replace_all)._ +### `list_slice` + +_Alias of [array_slice](#array_slice)._ + ### `list_to_string` _Alias of [list_to_string](#list_to_string)._ @@ -2118,6 +2275,8 @@ _Alias of [make_array](#make_array)._ Removes the last n elements from the array. +DEPRECATED: use `array_slice` instead! + ``` trim_array(array, n) ```
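Since `trim_array` is deprecated in favor of `array_slice`, a migration sketch (it assumes the 1-based, inclusive `array_slice` bounds and the one-argument `array_length` form shown in the examples above):

```
-- equivalent of trim_array([1, 2, 3, 4, 5], 2), i.e. dropping the last 2 elements
select array_slice([1, 2, 3, 4, 5], 1, array_length([1, 2, 3, 4, 5]) - 2);
-- expected result: [1, 2, 3]
```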