
Commit

Merge remote-tracking branch 'upstream/main' into count-for-all
jayzhan211 committed May 16, 2024
2 parents 5765d99 + 5a8348f commit ea81c6e
Showing 29 changed files with 389 additions and 231 deletions.
7 changes: 2 additions & 5 deletions datafusion-examples/examples/advanced_udaf.rs
@@ -31,7 +31,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
function::{AccumulatorArgs, GroupsAccumulatorSupportedArgs, StateFieldsArgs},
function::{AccumulatorArgs, StateFieldsArgs},
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

@@ -101,10 +101,7 @@ impl AggregateUDFImpl for GeoMeanUdaf {

/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

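For illustration only (not part of this commit's diff): under the new signature, `groups_accumulator_supported` receives the same `AccumulatorArgs` that `accumulator` gets, so an implementation can inspect the call before opting in. The sketch below assumes it sits inside an `impl AggregateUDFImpl for ...` block like `GeoMeanUdaf` above; the `is_distinct` check is a hypothetical policy, while the example in the diff simply returns `true`.

```rust
// Sketch only: lives inside an `impl AggregateUDFImpl for MyUdaf` block, with
// `AccumulatorArgs` imported from `datafusion_expr::function` as shown above.
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
    // Hypothetical policy: take the fast grouped path except for DISTINCT aggregates.
    !args.is_distinct
}
```
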
9 changes: 2 additions & 7 deletions datafusion-examples/examples/simplify_udaf_expression.rs
@@ -17,9 +17,7 @@

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::{
AggregateFunctionSimplification, GroupsAccumulatorSupportedArgs, StateFieldsArgs,
};
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};
@@ -76,10 +74,7 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

6 changes: 1 addition & 5 deletions datafusion/core/src/datasource/file_format/arrow.rs
@@ -40,7 +40,7 @@ use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::IpcWriteOptions;
use arrow::ipc::{root_as_message, CompressionType};
use arrow_schema::{ArrowError, Schema, SchemaRef};
use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics};
use datafusion_common::{not_impl_err, DataFusionError, Statistics};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::insert::{DataSink, DataSinkExec};
@@ -136,10 +136,6 @@ impl FileFormat for ArrowFormat {
order_requirements,
)) as _)
}

fn file_type(&self) -> FileType {
FileType::ARROW
}
}

/// Implements [`DataSink`] for writing to arrow_ipc files
5 changes: 0 additions & 5 deletions datafusion/core/src/datasource/file_format/avro.rs
@@ -23,7 +23,6 @@ use std::sync::Arc;
use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion_common::FileType;
use datafusion_physical_expr::PhysicalExpr;
use object_store::{GetResultPayload, ObjectMeta, ObjectStore};

@@ -89,10 +88,6 @@ impl FileFormat for AvroFormat {
let exec = AvroExec::new(conf);
Ok(Arc::new(exec))
}

fn file_type(&self) -> FileType {
FileType::AVRO
}
}

#[cfg(test)]
9 changes: 3 additions & 6 deletions datafusion/core/src/datasource/file_format/csv.rs
@@ -42,7 +42,7 @@ use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, Field, Fields, Schema};
use datafusion_common::config::CsvOptions;
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType};
use datafusion_common::{exec_err, not_impl_err, DataFusionError};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::metrics::MetricsSet;
@@ -280,10 +280,6 @@ impl FileFormat for CsvFormat {
order_requirements,
)) as _)
}

fn file_type(&self) -> FileType {
FileType::CSV
}
}

impl CsvFormat {
@@ -549,8 +545,9 @@ mod tests {

use arrow::compute::concat_batches;
use datafusion_common::cast::as_string_array;
use datafusion_common::internal_err;
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, GetExt};
use datafusion_common::{FileType, GetExt};
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::{col, lit};

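An illustrative aside on the test import change above (not part of this commit's diff): `FileType` and `GetExt` still live in `datafusion_common`; only the `FileFormat::file_type` trait method goes away in this merge. The sketch assumes `GetExt::get_ext` returns the format's canonical extension as a `String`.

```rust
use datafusion_common::{FileType, GetExt};

fn main() {
    // FileType is still available from datafusion_common; only the
    // FileFormat::file_type() trait method is removed in this merge.
    let file_type = FileType::CSV;
    // Assumed behaviour: get_ext() yields the canonical extension for the format.
    println!("extension: {}", file_type.get_ext());
}
```
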
6 changes: 1 addition & 5 deletions datafusion/core/src/datasource/file_format/json.rs
@@ -43,7 +43,7 @@ use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter};
use arrow_array::RecordBatch;
use datafusion_common::config::JsonOptions;
use datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::{not_impl_err, FileType};
use datafusion_common::not_impl_err;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::metrics::MetricsSet;
@@ -184,10 +184,6 @@ impl FileFormat for JsonFormat {
order_requirements,
)) as _)
}

fn file_type(&self) -> FileType {
FileType::JSON
}
}

impl Default for JsonSerializer {
5 changes: 1 addition & 4 deletions datafusion/core/src/datasource/file_format/mod.rs
@@ -41,7 +41,7 @@ use crate::error::Result;
use crate::execution::context::SessionState;
use crate::physical_plan::{ExecutionPlan, Statistics};

use datafusion_common::{not_impl_err, FileType};
use datafusion_common::not_impl_err;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};

use async_trait::async_trait;
@@ -104,9 +104,6 @@ pub trait FileFormat: Send + Sync + fmt::Debug {
) -> Result<Arc<dyn ExecutionPlan>> {
not_impl_err!("Writer not implemented for this format")
}

/// Returns the FileType corresponding to this FileFormat
fn file_type(&self) -> FileType;
}

#[cfg(test)]
6 changes: 1 addition & 5 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -47,7 +47,7 @@ use datafusion_common::config::TableParquetOptions;
use datafusion_common::file_options::parquet_writer::ParquetWriterOptions;
use datafusion_common::stats::Precision;
use datafusion_common::{
exec_err, internal_datafusion_err, not_impl_err, DataFusionError, FileType,
exec_err, internal_datafusion_err, not_impl_err, DataFusionError,
};
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;
@@ -286,10 +286,6 @@ impl FileFormat for ParquetFormat {
order_requirements,
)) as _)
}

fn file_type(&self) -> FileType {
FileType::PARQUET
}
}

fn summarize_min_max(
10 changes: 5 additions & 5 deletions datafusion/core/src/physical_optimizer/optimizer.rs
@@ -112,11 +112,6 @@ impl PhysicalOptimizer {
// Remove the ancillary output requirement operator since we are done with the planning
// phase.
Arc::new(OutputRequirements::new_remove_mode()),
// The PipelineChecker rule will reject non-runnable query plans that use
// pipeline-breaking operators on infinite input(s). The rule generates a
// diagnostic error message when this happens. It makes no changes to the
// given query plan; i.e. it only acts as a final gatekeeping rule.
Arc::new(PipelineChecker::new()),
// The aggregation limiter will try to find situations where the accumulator count
// is not tied to the cardinality, i.e. when the output of the aggregation is passed
// into an `order by max(x) limit y`. In this case it will copy the limit value down
@@ -129,6 +124,11 @@
// are not present, the load of executors such as join or union will be
// reduced by narrowing their input tables.
Arc::new(ProjectionPushdown::new()),
// The PipelineChecker rule will reject non-runnable query plans that use
// pipeline-breaking operators on infinite input(s). The rule generates a
// diagnostic error message when this happens. It makes no changes to the
// given query plan; i.e. it only acts as a final gatekeeping rule.
Arc::new(PipelineChecker::new()),
];

Self::with_rules(rules)
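
An illustrative sketch of the reordering above (not part of this commit's diff): `PipelineChecker` now runs after `ProjectionPushdown`, so it stays the final gatekeeping rule and validates the plan that will actually execute. The constructors (`new_remove_mode`, `new`, `with_rules`) are the ones visible in the diff; the module paths below are assumptions.

```rust
use std::sync::Arc;

// Assumed module paths; only the rule names and constructors appear in the diff above.
use datafusion::physical_optimizer::optimizer::PhysicalOptimizer;
use datafusion::physical_optimizer::output_requirements::OutputRequirements;
use datafusion::physical_optimizer::pipeline_checker::PipelineChecker;
use datafusion::physical_optimizer::projection_pushdown::ProjectionPushdown;
use datafusion::physical_optimizer::PhysicalOptimizerRule;

fn main() {
    // Reduced rule list mirroring the new ordering: rewriting rules first,
    // PipelineChecker last so it checks the final shape of the plan.
    let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
        Arc::new(OutputRequirements::new_remove_mode()),
        Arc::new(ProjectionPushdown::new()),
        Arc::new(PipelineChecker::new()),
    ];
    let _optimizer = PhysicalOptimizer::with_rules(rules);
}
```
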
34 changes: 3 additions & 31 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::{Field, Schema};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::windows::{
create_window_expr, BoundedWindowAggExec, WindowAggExec,
create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec,
};
use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted};
use datafusion::physical_plan::{collect, InputOrderMode};
@@ -40,7 +39,6 @@ use datafusion_expr::{
};
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use itertools::Itertools;
use test_utils::add_empty_batches;

use hashbrown::HashMap;
@@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
};

let extended_schema =
schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
schema_add_window_field(&args, &schema, &window_fn, fn_name)?;

let window_expr = create_window_expr(
&window_fn,
@@ -683,7 +681,7 @@ async fn run_window_test(
exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
}

let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?;

let usual_window_exec = Arc::new(WindowAggExec::try_new(
vec![create_window_expr(
@@ -754,32 +752,6 @@ async fn run_window_test(
Ok(())
}

// The planner has fully updated schema before calling the `create_window_expr`
// Replicate the same for this test
fn schema_add_window_fields(
args: &[Arc<dyn PhysicalExpr>],
schema: &Arc<Schema>,
window_fn: &WindowFunctionDefinition,
fn_name: &str,
) -> Result<Arc<Schema>> {
let data_types = args
.iter()
.map(|e| e.clone().as_ref().data_type(schema))
.collect::<Result<Vec<_>>>()?;
let window_expr_return_type = window_fn.return_type(&data_types)?;
let mut window_fields = schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect_vec();
window_fields.extend_from_slice(&[Field::new(
fn_name,
window_expr_return_type,
true,
)]);
Ok(Arc::new(Schema::new(window_fields)))
}

/// Return randomly sized record batches with:
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
/// one random int32 column x
@@ -726,10 +726,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
panic!("accumulator shouldn't invoke");
}

fn groups_accumulator_supported(
&self,
_args: GroupsAccumulatorSupportedArgs,
) -> bool {
fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}

49 changes: 24 additions & 25 deletions datafusion/expr/src/function.rs
@@ -45,8 +45,10 @@ pub type ReturnTypeFunction =
pub struct AccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,

/// The schema of the input arguments
pub schema: &'a Schema,

/// Whether to ignore nulls.
///
/// SQL allows the user to specify `IGNORE NULLS`, for example:
@@ -67,42 +69,39 @@
///
/// If no `ORDER BY` is specified, `sort_exprs`` will be empty.
pub sort_exprs: &'a [Expr],

/// Whether the aggregate function is distinct.
///
/// ```sql
/// SELECT COUNT(DISTINCT column1) FROM t;
/// ```
pub is_distinct: bool,
pub input_type: &'a DataType,
}

impl<'a> AccumulatorArgs<'a> {
pub fn new(
data_type: &'a DataType,
schema: &'a Schema,
ignore_nulls: bool,
sort_exprs: &'a [Expr],
is_distinct: bool,
input_type: &'a DataType,
) -> Self {
Self {
data_type,
schema,
ignore_nulls,
sort_exprs,
is_distinct,
input_type,
}
}
}
/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// [`GroupsAccumulatorSupportedArgs`] contains information to determine if an
/// aggregate function supports the groups accumulator.
pub struct GroupsAccumulatorSupportedArgs {
/// The number of arguments the aggregate function takes.
pub args_num: usize,
pub is_distinct: bool,
}

/// [`StateFieldsArgs`] contains information about the fields that an
/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`].
///
/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields
pub struct StateFieldsArgs<'a> {
/// The name of the aggregate function.
pub name: &'a str,

/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// The return type of the aggregate function.
pub return_type: &'a DataType,

/// The ordering fields of the aggregate function.
pub ordering_fields: &'a [Field],

/// Whether the aggregate function is distinct.
pub is_distinct: bool,
}

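To make the reshaped argument structs concrete, here is an illustrative sketch (not part of this commit's diff). It only reads fields shown above; the helper names and the two-field state layout are hypothetical.

```rust
use arrow_schema::{DataType, Field};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};

// Hypothetical helper: decide whether a UDAF should take the fast grouped path,
// using only fields documented on AccumulatorArgs above.
fn wants_groups_accumulator(args: &AccumulatorArgs) -> bool {
    !args.is_distinct && args.sort_exprs.is_empty()
}

// Hypothetical helper: derive an accumulator state layout (a running sum plus a
// row count) from the fields documented on StateFieldsArgs above.
fn example_state_fields(args: &StateFieldsArgs) -> Vec<Field> {
    let mut fields = vec![
        Field::new(format!("{}_sum", args.name), args.input_type.clone(), true),
        Field::new(format!("{}_count", args.name), DataType::UInt64, true),
    ];
    // Any ORDER BY-related state the planner requires is appended unchanged.
    fields.extend(args.ordering_fields.iter().cloned());
    fields
}
```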