From d840e987cd855fbbdc5d3e5d69683a5b4f279bb6 Mon Sep 17 00:00:00 2001 From: Sherin Jacob Date: Sat, 16 Nov 2024 00:49:15 +0530 Subject: [PATCH] fix: serialize user-defined window functions to proto (#13421) * Adds roundtrip physical plan test * Adds enum for udwf to `WindowFunction` * initial fix for serializing udwf * Revives deleted test * Adds codec methods for physical plan * Rewrite error message * Minor: rename binding + formatting fixes * Extends `PhysicalExtensionCodec` for udwf * Minor: formatting * Restricts visibility to tests --- datafusion/physical-plan/src/windows/mod.rs | 8 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 13 ++ datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/physical_plan/from_proto.rs | 6 + datafusion/proto/src/physical_plan/mod.rs | 10 +- .../proto/src/physical_plan/to_proto.rs | 25 ++- datafusion/proto/tests/cases/mod.rs | 60 ++++++- .../tests/cases/roundtrip_physical_plan.rs | 160 +++++++++++++++++- 9 files changed, 272 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a323a958cc76..32173c3ef17d 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -194,7 +194,7 @@ pub fn create_udwf_window_expr( /// Implements [`BuiltInWindowFunctionExpr`] for [`WindowUDF`] #[derive(Clone, Debug)] -struct WindowUDFExpr { +pub struct WindowUDFExpr { fun: Arc, args: Vec>, /// Display name @@ -209,6 +209,12 @@ struct WindowUDFExpr { ignore_nulls: bool, } +impl WindowUDFExpr { + pub fn fun(&self) -> &Arc { + &self.fun + } +} + impl BuiltInWindowFunctionExpr for WindowUDFExpr { fn as_any(&self) -> &dyn std::any::Any { self diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6606b1e93f02..504e5e1ceead 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -853,6 +853,7 @@ message PhysicalWindowExprNode { oneof window_function { // BuiltInWindowFunction built_in_function = 2; string user_defined_aggr_function = 3; + string user_defined_window_function = 10; } repeated PhysicalExprNode args = 4; repeated PhysicalExprNode partition_by = 5; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 09c873b1f98a..29920814a802 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -16326,6 +16326,9 @@ impl serde::Serialize for PhysicalWindowExprNode { physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { struct_ser.serialize_field("userDefinedAggrFunction", v)?; } + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(v) => { + struct_ser.serialize_field("userDefinedWindowFunction", v)?; + } } } struct_ser.end() @@ -16350,6 +16353,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "funDefinition", "user_defined_aggr_function", "userDefinedAggrFunction", + "user_defined_window_function", + "userDefinedWindowFunction", ]; #[allow(clippy::enum_variant_names)] @@ -16361,6 +16366,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { Name, FunDefinition, UserDefinedAggrFunction, + UserDefinedWindowFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16389,6 +16395,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { "name" => Ok(GeneratedField::Name), "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), + "userDefinedWindowFunction" | "user_defined_window_function" => Ok(GeneratedField::UserDefinedWindowFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16461,6 +16468,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction); } + GeneratedField::UserDefinedWindowFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedWindowFunction")); + } + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedWindowFunction); + } } } Ok(PhysicalWindowExprNode { diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ad5320fc657c..07090b7cba11 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1266,7 +1266,7 @@ pub struct PhysicalWindowExprNode { pub name: ::prost::alloc::string::String, #[prost(bytes = "vec", optional, tag = "9")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "3")] + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "3, 10")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, @@ -1278,6 +1278,8 @@ pub mod physical_window_expr_node { /// BuiltInWindowFunction built_in_function = 2; #[prost(string, tag = "3")] UserDefinedAggrFunction(::prost::alloc::string::String), + #[prost(string, tag = "10")] + UserDefinedWindowFunction(::prost::alloc::string::String), } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c5bdd0c02ba..e528b38b84a8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -152,6 +152,12 @@ pub fn parse_physical_window_expr( None => registry.udaf(udaf_name)? }) } + protobuf::physical_window_expr_node::WindowFunction::UserDefinedWindowFunction(udwf_name) => { + WindowFunctionDefinition::WindowUDF(match &proto.fun_definition { + Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, + None => registry.udwf(udwf_name)? + }) + } } } else { return Err(proto_error("Missing required field in protobuf")); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 64e462d1695f..292ce13d0ede 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -64,7 +64,7 @@ use datafusion::physical_plan::{ ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::common::{byte_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ @@ -2119,6 +2119,14 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided for window function {name}") + } + + fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 60dcd650191d..7d9a524af828 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -19,14 +19,14 @@ use std::sync::Arc; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion::physical_plan::windows::PlainAggregateWindowExpr; +use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowUDFExpr}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion::{ datasource::{ @@ -68,7 +68,7 @@ pub fn serialize_physical_aggr_expr( ordering_req, distinct: aggr_expr.is_distinct(), ignore_nulls: aggr_expr.ignore_nulls(), - fun_definition: (!buf.is_empty()).then_some(buf) + fun_definition: (!buf.is_empty()).then_some(buf), }, )), }) @@ -120,6 +120,25 @@ pub fn serialize_physical_window_expr( window_frame, codec, )? + } else if let Some(built_in_window_expr) = expr.downcast_ref::() { + if let Some(expr) = built_in_window_expr + .get_built_in_func_expr() + .as_any() + .downcast_ref::() + { + let mut buf = Vec::new(); + codec.try_encode_udwf(expr.fun(), &mut buf)?; + ( + physical_window_expr_node::WindowFunction::UserDefinedWindowFunction( + expr.fun().name().to_string(), + ), + (!buf.is_empty()).then_some(buf), + ) + } else { + return not_impl_err!( + "User-defined window function not supported: {window_expr:?}" + ); + } } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index fbb2cd8f1e83..4d69ca075483 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. +use arrow::datatypes::{DataType, Field}; use std::any::Any; - -use arrow::datatypes::DataType; +use std::fmt::Debug; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl, + Signature, Volatility, WindowUDFImpl, }; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -125,3 +128,54 @@ pub struct MyAggregateUdfNode { #[prost(string, tag = "1")] pub result: String, } + +#[derive(Debug)] +pub(in crate::cases) struct CustomUDWF { + signature: Signature, + payload: String, +} + +impl CustomUDWF { + pub fn new(payload: String) -> Self { + Self { + signature: Signature::exact(vec![DataType::Int64], Volatility::Immutable), + payload, + } + } +} + +impl WindowUDFImpl for CustomUDWF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "custom_udwf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(CustomUDWFEvaluator {})) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } +} + +#[derive(Debug)] +struct CustomUDWFEvaluator; + +impl PartitionEvaluator for CustomUDWFEvaluator {} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub(in crate::cases) struct CustomUDWFNode { + #[prost(string, tag = "1")] + pub payload: String, +} diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index aab63dd8bd66..efa462aa7a85 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -32,7 +32,10 @@ use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; -use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; +use crate::cases::{ + CustomUDWF, CustomUDWFNode, MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, + MyRegexUdfNode, +}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -47,9 +50,11 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::functions_window::nth_value::nth_value_udwf; +use datafusion::functions_window::row_number::row_number_udwf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; use datafusion::physical_expr::expressions::Literal; -use datafusion::physical_expr::window::SlidingAggregateWindowExpr; +use datafusion::physical_expr::window::{BuiltInWindowExpr, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{ LexOrdering, LexRequirement, PhysicalSortRequirement, ScalarFunctionExpr, }; @@ -73,8 +78,13 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; -use datafusion::physical_plan::windows::{PlainAggregateWindowExpr, WindowAggExec}; -use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Statistics}; +use datafusion::physical_plan::windows::{ + create_udwf_window_expr, BoundedWindowAggExec, PlainAggregateWindowExpr, + WindowAggExec, +}; +use datafusion::physical_plan::{ + ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics, +}; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; @@ -87,7 +97,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, WindowUDF, }; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; @@ -263,12 +273,74 @@ fn roundtrip_nested_loop_join() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_udwf() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + create_udwf_window_expr( + &row_number_udwf(), + &[], + &schema, + "row_number() PARTITION BY [a] ORDER BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?, + &[ + col("a", &schema)? + ], + &LexOrdering::new(vec![ + PhysicalSortExpr::new(col("b", &schema)?, SortOptions::new(true, true)), + ]), + Arc::new(WindowFrame::new(None)), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + + roundtrip_test(Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + vec![col("a", &schema)?], + InputOrderMode::Sorted, + )?)) +} + #[test] fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let nth_value_window = + create_udwf_window_expr( + &nth_value_udwf(), + &[col("a", &schema)?, + lit(2)], schema.as_ref(), + "NTH_VALUE(a, 2) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + nth_value_window, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( AggregateExprBuilder::new( avg_udaf(), @@ -306,7 +378,7 @@ fn roundtrip_window() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); roundtrip_test(Arc::new(WindowAggExec::try_new( - vec![plain_aggr_window_expr, sliding_aggr_window_expr], + vec![plain_aggr_window_expr, sliding_aggr_window_expr, udwf_expr], input, vec![col("b", &schema)?], )?)) @@ -948,6 +1020,33 @@ impl PhysicalExtensionCodec for UDFExtensionCodec { } Ok(()) } + + fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "custom_udwf" { + let proto = CustomUDWFNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode custom_udwf: {err}")) + })?; + + Ok(Arc::new(WindowUDF::from(CustomUDWF::new(proto.payload)))) + } else { + not_impl_err!( + "unrecognized user-defined window function implementation, cannot decode" + ) + } + } + + fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udwf) = binding.as_any().downcast_ref::() { + let proto = CustomUDWFNode { + payload: udwf.payload.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udwf: {err:?}")) + })?; + } + Ok(()) + } } #[test] @@ -1005,6 +1104,55 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } +#[test] +fn roundtrip_udwf_extension_codec() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let custom_udwf = Arc::new(WindowUDF::from(CustomUDWF::new("payload".to_string()))); + let udwf = create_udwf_window_expr( + &custom_udwf, + &[col("a", &schema)?], + schema.as_ref(), + "custom_udwf(a) PARTITION BY [b] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW".to_string(), + false, + )?; + + let window_frame = WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::Int64(None)), + WindowFrameBound::CurrentRow, + ); + + let udwf_expr = Arc::new(BuiltInWindowExpr::new( + udwf, + &[col("b", &schema)?], + &LexOrdering { + inner: vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + }, + Arc::new(window_frame), + )); + + let input = Arc::new(EmptyExec::new(schema.clone())); + let window = Arc::new(BoundedWindowAggExec::try_new( + vec![udwf_expr], + input, + vec![col("b", &schema)?], + InputOrderMode::Sorted, + )?); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(window, &ctx, &UDFExtensionCodec)?; + Ok(()) +} + #[test] fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true);