
Commit 315b8e9: replace parts of test
Signed-off-by: jayzhan211 <[email protected]>
jayzhan211 committed Jul 23, 2024
1 parent 5d98c32 commit 315b8e9
Showing 5 changed files with 48 additions and 89 deletions.
@@ -420,7 +420,7 @@ pub(crate) mod tests {
// Return appropriate expr depending if COUNT is for col or table (*)
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
.schema(schema.clone())
.schema(Arc::new(schema.clone()))
.name(self.column_name())
.build()
.unwrap()
14 changes: 7 additions & 7 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -22,7 +22,7 @@ pub mod stats;
pub mod tdigest;
pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
@@ -75,7 +75,7 @@ pub fn create_aggregate_expr(
builder = builder.sort_exprs(sort_exprs.to_vec());
builder = builder.order_by(ordering_req.to_vec());
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.schema(schema.clone());
builder = builder.schema(Arc::new(schema.clone()));
builder = builder.name(name);

if ignore_nulls {
@@ -109,7 +109,7 @@ pub fn create_aggregate_expr_with_dfschema(
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.dfschema(dfschema.clone());
let schema: Schema = dfschema.into();
builder = builder.schema(schema);
builder = builder.schema(Arc::new(schema));
builder = builder.name(name);

if ignore_nulls {
@@ -134,7 +134,7 @@ pub struct AggregateExprBuilder {
logical_args: Vec<Expr>,
name: String,
/// Arrow Schema for the aggregate function
schema: Schema,
schema: SchemaRef,
/// Datafusion Schema for the aggregate function
dfschema: DFSchema,
/// The logical order by expressions, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
@@ -156,7 +156,7 @@ impl AggregateExprBuilder {
args,
logical_args: vec![],
name: String::new(),
schema: Schema::empty(),
schema: Arc::new(Schema::empty()),
dfschema: DFSchema::empty(),
sort_exprs: vec![],
ordering_req: vec![],
@@ -215,7 +215,7 @@ impl AggregateExprBuilder {
logical_args,
data_type,
name,
schema,
schema: Arc::unwrap_or_clone(schema),
dfschema,
sort_exprs,
ordering_req,
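
Note on the `Arc::unwrap_or_clone(schema)` call above: `build()` still needs an owned `Schema` internally, and `std::sync::Arc::unwrap_or_clone` moves the value out when the builder holds the only reference, falling back to a clone otherwise. A small standalone sketch of that std behavior, using `String` as a stand-in for `Schema`:

use std::sync::Arc;

fn main() {
    // Sole handle: the inner value is moved out of the Arc, no clone happens.
    let unique = Arc::new(String::from("only handle"));
    assert_eq!(Arc::unwrap_or_clone(unique), "only handle");

    // Shared handle: another Arc still points at the value, so it is cloned.
    let shared = Arc::new(String::from("shared"));
    let other = Arc::clone(&shared);
    assert_eq!(Arc::unwrap_or_clone(shared), "shared");
    assert_eq!(*other, "shared"); // the remaining handle is unaffected
}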
@@ -232,7 +232,7 @@ impl AggregateExprBuilder {
self
}

pub fn schema(mut self, schema: Schema) -> Self {
pub fn schema(mut self, schema: SchemaRef) -> Self {
self.schema = schema;
self
}
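
With `schema` now taking a `SchemaRef`, callers hand the builder an `Arc<Schema>` and clone only the Arc, not the schema itself. Below is a minimal sketch of the updated call pattern; it mirrors the test changes in this commit, and the import paths (in particular `count_udaf`, `Literal`, and the `AggregateExpr` trait location) are assumptions that may shift between DataFusion versions:

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr_common::aggregate::{AggregateExpr, AggregateExprBuilder};
use datafusion_physical_expr_common::expressions::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

// Build a COUNT(1) aggregate expression against a shared schema.
fn count_one(schema: &Arc<Schema>) -> Result<Arc<dyn AggregateExpr>> {
    let one: Arc<dyn PhysicalExpr> = Arc::new(Literal::new(ScalarValue::Int8(Some(1))));
    AggregateExprBuilder::new(count_udaf(), vec![one])
        .schema(Arc::clone(schema)) // cheap Arc clone; previously this setter took an owned Schema
        .name("COUNT(1)")
        .build()
}

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let _count_expr = count_one(&schema)?;
    Ok(())
}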
60 changes: 19 additions & 41 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -1211,7 +1211,7 @@ mod tests {

use crate::common::collect;
use datafusion_physical_expr_common::aggregate::{
create_aggregate_expr, create_aggregate_expr_with_dfschema,
create_aggregate_expr, create_aggregate_expr_with_dfschema, AggregateExprBuilder,
};
use datafusion_physical_expr_common::expressions::Literal;
use futures::{FutureExt, Stream};
@@ -1351,18 +1351,11 @@ mod tests {
],
};

let aggregates = vec![create_aggregate_expr(
&count_udaf(),
&[lit(1i8)],
&[datafusion_expr::lit(1i8)],
&[],
&[],
&input_schema,
"COUNT(1)",
false,
false,
false,
)?];
let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
.schema(Arc::clone(&input_schema))
.name("COUNT(1)")
.logical_exprs(vec![datafusion_expr::lit(1i8)])
.build()?];

let task_ctx = if spill {
new_spill_ctx(4, 1000)
@@ -1501,18 +1494,13 @@
groups: vec![vec![false]],
};

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
&avg_udaf(),
&[col("b", &input_schema)?],
&[datafusion_expr::col("b")],
&[],
&[],
&input_schema,
"AVG(b)",
false,
false,
false,
)?];
let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?])
.schema(Arc::clone(&input_schema))
.name("AVG(b)")
.build()?,
];

let task_ctx = if spill {
// set to an appropriate value to trigger spill
@@ -1803,21 +1791,11 @@
}

// Median(a)
fn test_median_agg_expr(schema: &Schema) -> Result<Arc<dyn AggregateExpr>> {
let args = vec![col("a", schema)?];
let fun = median_udaf();
datafusion_physical_expr_common::aggregate::create_aggregate_expr(
&fun,
&args,
&[],
&[],
&[],
schema,
"MEDIAN(a)",
false,
false,
false,
)
fn test_median_agg_expr(schema: SchemaRef) -> Result<Arc<dyn AggregateExpr>> {
AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
.schema(schema)
.name("MEDIAN(a)")
.build()
}

#[tokio::test]
@@ -1840,7 +1818,7 @@ mod tests {

// something that allocates within the aggregator
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> =
vec![test_median_agg_expr(&input_schema)?];
vec![test_median_agg_expr(Arc::clone(&input_schema))?];

// use fast-path in `row_hash.rs`.
let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
1 change: 1 addition & 0 deletions datafusion/proto/Cargo.toml
@@ -50,6 +50,7 @@ chrono = { workspace = true }
datafusion = { workspace = true, default-features = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
datafusion-proto-common = { workspace = true }
object_store = { workspace = true }
pbjson = { version = "0.6.0", optional = true }
60 changes: 20 additions & 40 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -24,6 +24,7 @@ use std::vec;

use arrow::array::RecordBatch;
use arrow::csv::WriterBuilder;
use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
use prost::Message;

use datafusion::arrow::array::ArrayRef;
@@ -86,7 +87,7 @@ use datafusion_expr::{
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
use datafusion_functions_aggregate::string_agg::StringAgg;
use datafusion_functions_aggregate::string_agg::string_agg_udaf;
use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
@@ -357,49 +358,28 @@ fn rountrip_aggregate() -> Result<()> {
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
vec![(col("a", &schema)?, "unused".to_string())];

let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
.schema(Arc::clone(&schema))
.name("AVG(b)")
.build()?;
let nth_expr =
AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)])
.schema(Arc::clone(&schema))
.name("NTH_VALUE(b, 1)")
.build()?;
let str_agg_expr = AggregateExprBuilder::new(
    string_agg_udaf(),
    vec![
        cast(col("b", &schema)?, &schema, DataType::Utf8)?,
        lit(ScalarValue::Utf8(Some(",".to_string()))),
    ],
)
.schema(Arc::clone(&schema))
.name("STRING_AGG(name, ',')")
.build()?;

let test_cases: Vec<Vec<Arc<dyn AggregateExpr>>> = vec![
// AVG
vec![create_aggregate_expr(
&avg_udaf(),
&[col("b", &schema)?],
&[],
&[],
&[],
&schema,
"AVG(b)",
false,
false,
false,
)?],
vec![avg_expr],
// NTH_VALUE
vec![create_aggregate_expr(
&nth_value_udaf(),
&[col("b", &schema)?, lit(1u64)],
&[],
&[],
&[],
&schema,
"NTH_VALUE(b, 1)",
false,
false,
false,
)?],
vec![nth_expr],
// STRING_AGG
vec![create_aggregate_expr(
&AggregateUDF::new_from_impl(StringAgg::new()),
&[
cast(col("b", &schema)?, &schema, DataType::Utf8)?,
lit(ScalarValue::Utf8(Some(",".to_string()))),
],
&[],
&[],
&[],
&schema,
"STRING_AGG(name, ',')",
false,
false,
false,
)?],
vec![str_agg_expr],
];

for aggregates in test_cases {
