diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 2e64403c64..0824235718 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -45,6 +45,9 @@ import org.apache.comet.shims.ShimCometConf */ object CometConf extends ShimCometConf { + val COMPAT_GUIDE: String = "For more information, refer to the Comet Compatibility " + + "Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html)" + private val TUNING_GUIDE = "For more information, refer to the Comet Tuning " + "Guide (https://datafusion.apache.org/comet/user-guide/tuning.html)" @@ -605,20 +608,27 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_EXPR_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] = + conf("spark.comet.expression.allowIncompatible") + .doc( + "Comet is not currently fully compatible with Spark for all expressions. " + + s"Set this config to true to allow them anyway. $COMPAT_GUIDE.") + .booleanConf + .createWithDefault(false) + val COMET_CAST_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] = conf("spark.comet.cast.allowIncompatible") .doc( "Comet is not currently fully compatible with Spark for all cast operations. " + - "Set this config to true to allow them anyway. See compatibility guide " + - "for more information.") + s"Set this config to true to allow them anyway. $COMPAT_GUIDE.") .booleanConf .createWithDefault(false) val COMET_REGEXP_ALLOW_INCOMPATIBLE: ConfigEntry[Boolean] = conf("spark.comet.regexp.allowIncompatible") - .doc("Comet is not currently fully compatible with Spark for all regular expressions. " + - "Set this config to true to allow them anyway using Rust's regular expression engine. " + - "See compatibility guide for more information.") + .doc( + "Comet is not currently fully compatible with Spark for all regular expressions. " + + s"Set this config to true to allow them anyway. $COMPAT_GUIDE.") .booleanConf .createWithDefault(false) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 2bea501e5c..8245e7b76b 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -25,7 +25,7 @@ Comet provides the following configuration settings. |--------|-------------|---------------| | spark.comet.batchSize | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 | | spark.comet.caseConversion.enabled | Java uses locale-specific rules when converting strings to upper or lower case and Rust does not, so we disable upper and lower by default. | false | -| spark.comet.cast.allowIncompatible | Comet is not currently fully compatible with Spark for all cast operations. Set this config to true to allow them anyway. See compatibility guide for more information. | false | +| spark.comet.cast.allowIncompatible | Comet is not currently fully compatible with Spark for all cast operations. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | spark.comet.columnar.shuffle.async.enabled | Whether to enable asynchronous shuffle for Arrow-based shuffle. | false | | spark.comet.columnar.shuffle.async.max.thread.num | Maximum number of threads on an executor used for Comet async columnar shuffle. This is the upper bound of total number of shuffle threads per executor. 
In other words, if the number of cores * the number of shuffle threads per task `spark.comet.columnar.shuffle.async.thread.num` is larger than this config. Comet will use this config as the number of shuffle threads per executor instead. | 100 | | spark.comet.columnar.shuffle.async.thread.num | Number of threads used for Comet async columnar shuffle per shuffle task. Note that more threads means more memory requirement to buffer shuffle data before flushing to disk. Also, more threads may not always improve performance, and should be set based on the number of cores available. | 3 | @@ -64,6 +64,7 @@ Comet provides the following configuration settings. | spark.comet.explain.native.enabled | When this setting is enabled, Comet will provide a tree representation of the native query plan before execution and again after execution, with metrics. | false | | spark.comet.explain.verbose.enabled | When this setting is enabled, Comet will provide a verbose tree representation of the extended information. | false | | spark.comet.explainFallback.enabled | When this setting is enabled, Comet will provide logging explaining the reason(s) why a query stage cannot be executed natively. Set this to false to reduce the amount of logging. | false | +| spark.comet.expression.allowIncompatible | Comet is not currently fully compatible with Spark for all expressions. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | spark.comet.memory.overhead.factor | Fraction of executor memory to be allocated as additional non-heap memory per executor process for Comet. | 0.2 | | spark.comet.memory.overhead.min | Minimum amount of additional memory to be allocated per executor process for Comet, in MiB. | 402653184b | | spark.comet.nativeLoadRequired | Whether to require Comet native library to load successfully when Comet is enabled. If not, Comet will silently fallback to Spark when it fails to load the native lib. Otherwise, an error will be thrown and the Spark job will be aborted. | false | @@ -73,7 +74,7 @@ Comet provides the following configuration settings. | spark.comet.parquet.read.io.mergeRanges.delta | The delta in bytes between consecutive read ranges below which the parallel reader will try to merge the ranges. The default is 8MB. | 8388608 | | spark.comet.parquet.read.parallel.io.enabled | Whether to enable Comet's parallel reader for Parquet files. The parallel reader reads ranges of consecutive data in a file in parallel. It is faster for large files and row groups but uses more resources. | true | | spark.comet.parquet.read.parallel.io.thread-pool.size | The maximum number of parallel threads the parallel reader will use in a single executor. For executors configured with a smaller number of cores, use a smaller number. | 16 | -| spark.comet.regexp.allowIncompatible | Comet is not currently fully compatible with Spark for all regular expressions. Set this config to true to allow them anyway using Rust's regular expression engine. See compatibility guide for more information. | false | +| spark.comet.regexp.allowIncompatible | Comet is not currently fully compatible with Spark for all regular expressions. Set this config to true to allow them anyway. For more information, refer to the Comet Compatibility Guide (https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | spark.comet.scan.enabled | Whether to enable native scans. 
When this is turned on, Spark will use Comet to read supported data sources (currently only Parquet is supported natively). Note that to enable native vectorized execution, both this config and 'spark.comet.exec.enabled' need to be enabled. | true | | spark.comet.scan.preFetch.enabled | Whether to enable pre-fetching feature of CometScan. | false | | spark.comet.scan.preFetch.threadNum | The number of threads running pre-fetching for CometScan. Effective if spark.comet.scan.preFetch.enabled is enabled. Note that more pre-fetching threads means more memory requirement to store pre-fetched row groups. | 2 | diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 99490f4406..6125aa074b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -182,15 +182,26 @@ The following Spark expressions are currently available. Any known compatibility | VariancePop | | | VarianceSamp | | -## Complex Types - -| Expression | Notes | -| ----------------- | ----------- | -| CreateNamedStruct | | -| ElementAt | Arrays only | -| GetArrayItem | | -| GetStructField | | -| StructsToJson | | +## Arrays + +| Expression | Notes | +|-------------------|--------------| +| ArrayAppend | Experimental | +| ArrayContains | Experimental | +| ArrayIntersect | Experimental | +| ArrayJoin | Experimental | +| ArrayRemove | Experimental | +| ArraysOverlap | Experimental | +| ElementAt | Arrays only | +| GetArrayItem | | + +## Structs + +| Expression | Notes | +|-------------------|--------------| +| CreateNamedStruct | | +| GetStructField | | +| StructsToJson | | ## Other diff --git a/docs/templates/compatibility-template.md b/docs/templates/compatibility-template.md index f6a725ac65..c63876e904 100644 --- a/docs/templates/compatibility-template.md +++ b/docs/templates/compatibility-template.md @@ -32,12 +32,6 @@ be used in production. There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support. -## Regular Expressions - -Comet uses the Rust regexp crate for evaluating regular expressions, and this has different behavior from Java's -regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but -this can be overridden by setting `spark.comet.regexp.allowIncompatible=true`. - ## Floating number comparison Spark normalizes NaN and zero for floating point numbers for several cases. See `NormalizeFloatingNumbers` optimization rule in Spark. @@ -46,6 +40,22 @@ because they are handled well in Spark (e.g., `SQLOrderingUtil.compareFloats`). functions of arrow-rs used by DataFusion do not normalize NaN and zero (e.g., [arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)). So Comet will add additional normalization expression of NaN and zero for comparison. +## Incompatible Expressions + +Some Comet native expressions are not 100% compatible with Spark and are disabled by default. These expressions +will fall back to Spark but can be enabled by setting `spark.comet.expression.allowIncompatible=true`. + +## Array Expressions + +Comet has experimental support for a number of array expressions. These are currently marked as incompatible +and can be enabled by setting `spark.comet.expression.allowIncompatible=true`. 
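+
+As an illustrative sketch (assuming an active `SparkSession` named `spark`), this setting can be enabled at
+runtime before running a query that uses these expressions:
+
+```scala
+// Opt in to Comet expressions that are not yet 100% compatible with Spark,
+// such as the experimental array expressions listed above.
+spark.conf.set("spark.comet.expression.allowIncompatible", "true")
+```
+
+The same key can also be supplied with `--conf` when submitting the application.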
+ +## Regular Expressions + +Comet uses the Rust regexp crate for evaluating regular expressions, and this has different behavior from Java's +regular expression engine. Comet will fall back to Spark for patterns that are known to produce different results, but +this can be overridden by setting `spark.comet.regexp.allowIncompatible=true`. + ## Cast Cast operations in Comet fall into three levels of support: diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cb4fffc1a3..7b375bc23f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -929,6 +929,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim binding: Boolean): Option[Expr] = { SQLConf.get + def convert(handler: CometExpressionSerde): Option[Expr] = { + handler match { + case _: IncompatExpr if !CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.get() => + withInfo( + expr, + s"$expr is not fully compatible with Spark. To enable it anyway, set " + + s"${CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key}=true. ${CometConf.COMPAT_GUIDE}.") + None + case _ => + handler.convert(expr, inputs, binding) + } + } + expr match { case a @ Alias(_, _) => val r = exprToProtoInternal(a.child, inputs, binding) @@ -2371,83 +2384,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim withInfo(expr, "unsupported arguments for GetArrayStructFields", child) None } - case expr: ArrayRemove => CometArrayRemove.convert(expr, inputs, binding) - case expr if expr.prettyName == "array_contains" => - createBinaryExpr( - expr, - expr.children(0), - expr.children(1), - inputs, - binding, - (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) - case _ if expr.prettyName == "array_append" => - createBinaryExpr( - expr, - expr.children(0), - expr.children(1), - inputs, - binding, - (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) - case _ if expr.prettyName == "array_intersect" => - createBinaryExpr( - expr, - expr.children(0), - expr.children(1), - inputs, - binding, - (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) - case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) => - val arrayExprProto = exprToProto(arrayExpr, inputs, binding) - val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding) - - if (arrayExprProto.isDefined && delimiterExprProto.isDefined) { - val arrayJoinBuilder = nullReplacementExpr match { - case Some(nrExpr) => - val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding) - ExprOuterClass.ArrayJoin - .newBuilder() - .setArrayExpr(arrayExprProto.get) - .setDelimiterExpr(delimiterExprProto.get) - .setNullReplacementExpr(nullReplacementExprProto.get) - case None => - ExprOuterClass.ArrayJoin - .newBuilder() - .setArrayExpr(arrayExprProto.get) - .setDelimiterExpr(delimiterExprProto.get) - } - Some( - ExprOuterClass.Expr - .newBuilder() - .setArrayJoin(arrayJoinBuilder) - .build()) - } else { - val exprs: List[Expression] = nullReplacementExpr match { - case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr) - case None => List(arrayExpr, delimiterExpr) - } - withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*) - None - } - case ArraysOverlap(leftArrayExpr, rightArrayExpr) => - if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { - createBinaryExpr( - expr, - leftArrayExpr, - rightArrayExpr, - inputs, - binding, - (builder, 
binaryExpr) => builder.setArraysOverlap(binaryExpr)) - } else { - withInfo( - expr, - s"$expr is not supported yet. To enable all incompatible casts, set " + - s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true") - None - } + case _: ArrayRemove => convert(CometArrayRemove) + case _: ArrayContains => convert(CometArrayContains) + // Function introduced in 3.4.0. Refer by name to provide compatibility + // with older Spark builds + case _ if expr.prettyName == "array_append" => convert(CometArrayAppend) + case _: ArrayIntersect => convert(CometArrayIntersect) + case _: ArrayJoin => convert(CometArrayJoin) + case _: ArraysOverlap => convert(CometArraysOverlap) case _ => withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None } + } /** @@ -3490,3 +3439,6 @@ trait CometExpressionSerde { inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] } + +/** Marker trait for an expression that is not guaranteed to be 100% compatible with Spark */ +trait IncompatExpr {} diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 9058a641ee..db1679f22b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -19,11 +19,11 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, ArrayRemove, Attribute, Expression} import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType} import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProto} import org.apache.comet.shims.CometExprShim object CometArrayRemove extends CometExpressionSerde with CometExprShim { @@ -65,3 +65,103 @@ object CometArrayRemove extends CometExpressionSerde with CometExprShim { (builder, binaryExpr) => builder.setArrayRemove(binaryExpr)) } } + +object CometArrayAppend extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) + } +} + +object CometArrayContains extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) + } +} + +object CometArrayIntersect extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) + } +} + +object CometArraysOverlap extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArraysOverlap(binaryExpr)) + } 
+} + +object CometArrayJoin extends CometExpressionSerde with IncompatExpr { + override def convert( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExpr = expr.asInstanceOf[ArrayJoin] + val arrayExprProto = exprToProto(arrayExpr.array, inputs, binding) + val delimiterExprProto = exprToProto(arrayExpr.delimiter, inputs, binding) + + if (arrayExprProto.isDefined && delimiterExprProto.isDefined) { + val arrayJoinBuilder = arrayExpr.nullReplacement match { + case Some(nrExpr) => + val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding) + ExprOuterClass.ArrayJoin + .newBuilder() + .setArrayExpr(arrayExprProto.get) + .setDelimiterExpr(delimiterExprProto.get) + .setNullReplacementExpr(nullReplacementExprProto.get) + case None => + ExprOuterClass.ArrayJoin + .newBuilder() + .setArrayExpr(arrayExprProto.get) + .setDelimiterExpr(delimiterExprProto.get) + } + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayJoin(arrayJoinBuilder) + .build()) + } else { + val exprs: List[Expression] = arrayExpr.nullReplacement match { + case Some(nrExpr) => List(arrayExpr, arrayExpr.delimiter, nrExpr) + case None => List(arrayExpr, arrayExpr.delimiter) + } + withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*) + None + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 5727f9f907..df1fccb698 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -25,8 +25,10 @@ import scala.util.Random import org.apache.hadoop.fs.Path import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.{array, col, expr, lit, udf} import org.apache.spark.sql.types.StructType +import org.apache.comet.CometSparkSessionExtensions.{isSpark34Plus, isSpark35Plus} import org.apache.comet.testing.{DataGenOptions, ParquetGenerator} class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { @@ -131,4 +133,163 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp checkExplainString = false) } } + + test("array_append") { + assume(isSpark34Plus) + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1"); + checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1")); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1")); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_append(array(_8), 'test') FROM t1")); + checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1")); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); + } + } + } + } + + test("array_prepend") { + assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via 
array_insert + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1"); + checkSparkAnswerAndOperator(spark.sql("Select array_prepend(array(_1),false) from t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1")); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1")); + checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1")); + checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_19), _19) FROM t1")); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); + } + } + } + + test("ArrayInsert") { + assume(isSpark34Plus) + Seq(true, false).foreach(dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + val df = spark.read + .parquet(path.toString) + .withColumn("arr", array(col("_4"), lit(null), col("_4"))) + .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)")) + .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)")) + .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)")) + .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)")) + .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)")) + checkSparkAnswerAndOperator(df.select("arrInsertResult")) + checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult")) + checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize")) + checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize")) + checkSparkAnswerAndOperator(df.select("arrInsertNone")) + }) + } + + test("ArrayInsertUnsupportedArgs") { + // This test checks that the else branch in ArrayInsert + // mapping to the comet is valid and fallback to spark is working fine. 
+ assume(isSpark34Plus) + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000) + val df = spark.read + .parquet(path.toString) + .withColumn("arr", array(col("_4"), lit(null), col("_4"))) + .withColumn("idx", udf((_: Int) => 1).apply(col("_4"))) + .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) + checkSparkAnswer(df.select("arrUnsupportedArgs")) + } + } + + test("array_contains") { + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1"); + checkSparkAnswerAndOperator( + spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) + checkSparkAnswerAndOperator( + spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); + } + } + } + + test("array_intersect") { + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_intersect(array(_18), array(_19)) from t1")) + } + } + } + } + + test("array_join") { + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql( + "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1")) + checkSparkAnswerAndOperator(sql( + "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1")) + checkSparkAnswerAndOperator(sql( + "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null")) + checkSparkAnswerAndOperator( + sql( + "SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1")) + } + } + } + } + + test("arrays_overlap") { + withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null")) + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null")) + checkSparkAnswerAndOperator(sql( + "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null")) + checkSparkAnswerAndOperator(spark.sql( + "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1")); + } + } 
+ } + } + } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f82101b3ac..6ffc2993dc 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE import org.apache.spark.sql.types.{Decimal, DecimalType} -import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark35Plus, isSpark40Plus} +import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark40Plus} class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ @@ -2575,151 +2575,4 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("array_append") { - assume(isSpark34Plus) - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1"); - checkSparkAnswerAndOperator(spark.sql("Select array_append(array(_1),false) from t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_append(array(_2, _3, _4), 4) FROM t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_append(array(_2, _3, _4), null) FROM t1")); - checkSparkAnswerAndOperator( - spark.sql("SELECT array_append(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1")); - checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_8), 'test') FROM t1")); - checkSparkAnswerAndOperator(spark.sql("SELECT array_append(array(_19), _19) FROM t1")); - checkSparkAnswerAndOperator( - spark.sql("SELECT array_append((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); - } - } - } - - test("array_prepend") { - assume(isSpark35Plus) // in Spark 3.5 array_prepend is implemented via array_insert - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1"); - checkSparkAnswerAndOperator(spark.sql("Select array_prepend(array(_1),false) from t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_prepend(array(_2, _3, _4), 4) FROM t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_prepend(array(_2, _3, _4), null) FROM t1")); - checkSparkAnswerAndOperator( - spark.sql("SELECT array_prepend(array(_6, _7), CAST(6.5 AS DOUBLE)) FROM t1")); - checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_8), 'test') FROM t1")); - checkSparkAnswerAndOperator(spark.sql("SELECT array_prepend(array(_19), _19) FROM t1")); - checkSparkAnswerAndOperator( - spark.sql("SELECT array_prepend((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); - } - } - } - - test("ArrayInsert") { - assume(isSpark34Plus) - Seq(true, false).foreach(dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - val df = spark.read - .parquet(path.toString) - .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)")) - .withColumn("arrInsertNegativeIndexResult", 
expr("array_insert(arr, -1, 1)")) - .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)")) - .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)")) - .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)")) - checkSparkAnswerAndOperator(df.select("arrInsertResult")) - checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult")) - checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize")) - checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize")) - checkSparkAnswerAndOperator(df.select("arrInsertNone")) - }) - } - - test("ArrayInsertUnsupportedArgs") { - // This test checks that the else branch in ArrayInsert - // mapping to the comet is valid and fallback to spark is working fine. - assume(isSpark34Plus) - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = false, 10000) - val df = spark.read - .parquet(path.toString) - .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("idx", udf((_: Int) => 1).apply(col("_4"))) - .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) - checkSparkAnswer(df.select("arrUnsupportedArgs")) - } - } - - test("array_contains") { - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1"); - checkSparkAnswerAndOperator( - spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1")) - checkSparkAnswerAndOperator( - spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); - } - } - - test("array_intersect") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - checkSparkAnswerAndOperator( - sql("SELECT array_intersect(array(_2, _3, _4), array(_3, _4)) from t1")) - checkSparkAnswerAndOperator( - sql("SELECT array_intersect(array(_2 * -1), array(_9, _10)) from t1")) - checkSparkAnswerAndOperator(sql("SELECT array_intersect(array(_18), array(_19)) from t1")) - } - } - } - - test("array_join") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - checkSparkAnswerAndOperator(sql( - "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ') from t1")) - checkSparkAnswerAndOperator(sql( - "SELECT array_join(array(cast(_1 as string), cast(_2 as string), cast(_6 as string)), ' @ ', ' +++ ') from t1")) - checkSparkAnswerAndOperator(sql( - "SELECT array_join(array('hello', 'world', cast(_2 as string)), ' ') from t1 where _2 is not null")) - checkSparkAnswerAndOperator( - sql("SELECT array_join(array('hello', '-', 'world', cast(_2 as string)), ' ') from t1")) - } - } - } - - test("arrays_overlap") { - withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled, 10000) - spark.read.parquet(path.toString).createOrReplaceTempView("t1") - checkSparkAnswerAndOperator(sql( - "SELECT 
arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null")) - checkSparkAnswerAndOperator(sql( - "SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null")) - checkSparkAnswerAndOperator(sql( - "SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null")) - checkSparkAnswerAndOperator(spark.sql( - "SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1")); - } - } - } - } - }