diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 3ce94c129..ff5b38697 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -767,19 +767,26 @@ impl PhysicalPlanner {
                 Ok(Arc::new(case_expr))
             }
             ExprStruct::ArrayJoin(expr) => {
-                let src_array_expr =
-                    self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
-                let key_expr =
-                    self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
-                let args = vec![Arc::clone(&src_array_expr), key_expr];
+                let array_expr =
+                    self.create_expr(expr.array_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                let delimiter_expr =
+                    self.create_expr(expr.delimiter_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
+
+                let mut args = vec![array_expr, delimiter_expr];
+                if expr.null_replacement_expr.is_some() {
+                    let null_replacement_expr =
+                        self.create_expr(expr.null_replacement_expr.as_ref().unwrap(), Arc::clone(&input_schema))?;
+                    args.push(null_replacement_expr)
+                }
+
                 let datafusion_array_to_string = array_to_string_udf();
-                let array_intersect_expr = Arc::new(ScalarFunctionExpr::new(
+                let array_join_expr = Arc::new(ScalarFunctionExpr::new(
                     "array_join",
                     datafusion_array_to_string,
                     args,
                     DataType::Utf8,
                 ));
-                Ok(array_intersect_expr)
+                Ok(array_join_expr)
             }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 4a9339fb8..271354c82 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -86,7 +86,7 @@ message Expr {
     ArrayInsert array_insert = 59;
     BinaryExpr array_contains = 60;
     BinaryExpr array_remove = 61;
-    BinaryExpr array_join = 63;
+    ArrayJoin array_join = 62;
   }
 }
 
@@ -415,6 +415,12 @@ message ArrayInsert {
   bool legacy_negative_index = 4;
 }
 
+message ArrayJoin {
+  Expr array_expr = 1;
+  Expr delimiter_expr = 2;
+  Expr null_replacement_expr = 3;
+}
+
 message DataType {
   enum DataTypeId {
     BOOL = 0;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 11242cbdb..022787666 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2284,12 +2284,38 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           expr.children(1),
           inputs,
           (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
-      case _ if expr.prettyName == "array_join" =>
-        createBinaryExpr(
-          expr.children(0),
-          expr.children(1),
-          inputs,
-          (builder, binaryExpr) => builder.setArrayJoin(binaryExpr))
+      case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
+        val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
+        val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
+        val nullReplacementExprProto = nullReplacementExpr.map(exprToProto(_, inputs, binding))
+        if (arrayExprProto.isDefined && delimiterExprProto.isDefined &&
+          nullReplacementExprProto.forall(_.isDefined)) {
+          val arrayJoinBuilder = nullReplacementExprProto.flatten match {
+            case Some(nrExprProto) =>
+              ExprOuterClass.ArrayJoin
+                .newBuilder()
+                .setArrayExpr(arrayExprProto.get)
+                .setDelimiterExpr(delimiterExprProto.get)
+                .setNullReplacementExpr(nrExprProto)
+            case None =>
+              ExprOuterClass.ArrayJoin
+                .newBuilder()
+                .setArrayExpr(arrayExprProto.get)
+                .setDelimiterExpr(delimiterExprProto.get)
+          }
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setArrayJoin(arrayJoinBuilder)
+              .build())
+        } else {
+          val exprs: List[Expression] = nullReplacementExpr match {
+            case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
+            case None => List(arrayExpr, delimiterExpr)
+          }
+          withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
+          None
+        }
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 78e24c7d7..16a936a5c 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2554,6 +2554,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       spark.read.parquet(path.toString).createOrReplaceTempView("t1")
       checkSparkAnswerAndOperator(sql(
         "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1"))
+      checkSparkAnswerAndOperator(sql(
+        "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1"))
       checkSparkAnswerAndOperator(sql(
         "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null"))
       checkSparkAnswerAndOperator(