Skip to content

Commit

Permalink
udwf, not udaf
Browse files Browse the repository at this point in the history
  • Loading branch information
berkaysynnada committed Nov 11, 2024
1 parent be02f03 commit 5c4e0f6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ fn window_expr_from_aggregate_expr(
}

/// Creates a `BuiltInWindowFunctionExpr` suitable for a user defined window function
fn create_udwf_window_expr(
pub fn create_udwf_window_expr(
fun: &Arc<WindowUDF>,
args: &[Arc<dyn PhysicalExpr>],
input_schema: &Schema,
Expand Down
31 changes: 16 additions & 15 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ use datafusion::datasource::physical_plan::{
};
use datafusion::execution::FunctionRegistry;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::functions_window::nth_value::nth_value_udwf;
use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility};
use datafusion::physical_expr::expressions::Literal;
use datafusion::physical_expr::window::SlidingAggregateWindowExpr;
use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr};
use datafusion::physical_expr::{
LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr,
};
Expand All @@ -73,7 +74,9 @@ use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec};
use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowAggExec};
use datafusion::physical_plan::windows::{
create_udwf_window_expr, PlainAggregateWindowExpr, WindowAggExec,
};
use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Statistics};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
Expand All @@ -85,9 +88,11 @@ use datafusion_common::stats::Precision;
use datafusion_common::{
internal_err, not_impl_err, DataFusionError, Result, UnnestOptions,
};
use datafusion_expr::WindowFunctionDefinition::WindowUDF;
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
WindowFunctionDefinition,
};
use datafusion_functions_aggregate::average::avg_udaf;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
Expand All @@ -96,6 +101,7 @@ use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use datafusion_proto::protobuf;
use datafusion_proto::protobuf::logical_expr_node::ExprType::WindowExpr;

/// Perform a serde roundtrip and assert that the string representation of the before and after plans
/// are identical. Note that this often isn't sufficient to guarantee that no information is
Expand Down Expand Up @@ -275,25 +281,20 @@ fn roundtrip_window() -> Result<()> {
WindowFrameBound::CurrentRow,
);

let args = vec![cast(col("a", &schema)?, &schema, DataType::Int64)?];
let nth_value_expr = AggregateExprBuilder::new(nth_value_udaf(), args)
.order_by(LexOrdering {
let nth_value_window =
create_udwf_window_expr(&nth_value_udwf(), &[col("a", &schema)?], schema.as_ref(), "FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), false)?;
let builtin_window_expr = Arc::new(BuiltInWindowExpr::new(
nth_value_window,
&[col("b", &schema)?],
&LexOrdering {
inner: vec![PhysicalSortExpr {
expr: col("a", &schema)?,
options: SortOptions {
descending: false,
nulls_first: false,
},
}],
})
.schema(Arc::clone(&schema))
.alias("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")
.build()
.map(Arc::new)?;
let sliding_aggr_window_nth_value = Arc::new(SlidingAggregateWindowExpr::new(
nth_value_expr,
&[col("b", &schema)?],
&LexOrdering::default(),
},
Arc::new(window_frame),
));

Expand Down Expand Up @@ -337,7 +338,7 @@ fn roundtrip_window() -> Result<()> {
vec![
plain_aggr_window_expr,
sliding_aggr_window_expr,
sliding_aggr_window_nth_value,
builtin_window_expr,
],
input,
vec![col("b", &schema)?],
Expand Down

0 comments on commit 5c4e0f6

Please sign in to comment.