From 383ae9197422da5948a31caf21722569ef2067fe Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 20 May 2024 08:31:07 -0600 Subject: [PATCH] chore: improve fallback message when comet native shuffle is not enabled (#445) * improve fallback message when comet native shuffle is not enabled * update test --- .../comet/CometSparkSessionExtensions.scala | 32 +++++++++++++++---- .../apache/comet/CometExpressionSuite.scala | 6 ++-- .../org/apache/spark/sql/CometTestBase.scala | 5 ++- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 7c269c411..85a19f55c 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -30,8 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.comet._ -import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} @@ -46,7 +45,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{createMessage, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, 
isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos} +import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -684,7 +683,8 @@ class CometSparkSessionExtensions case s: ShuffleExchangeExec => val isShuffleEnabled = isCometShuffleEnabled(conf) - val msg1 = createMessage(!isShuffleEnabled, "Native shuffle is not enabled") + val reason = getCometShuffleNotEnabledReason(conf).getOrElse("no reason available") + val msg1 = createMessage(!isShuffleEnabled, s"Native shuffle is not enabled: $reason") val columnarShuffleEnabled = isCometColumnarShuffleEnabled(conf) val msg2 = createMessage( isShuffleEnabled && !columnarShuffleEnabled && !QueryPlanSerde @@ -933,13 +933,31 @@ object CometSparkSessionExtensions extends Logging { } private[comet] def isCometShuffleEnabled(conf: SQLConf): Boolean = - COMET_EXEC_SHUFFLE_ENABLED.get(conf) && - (conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") == - "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager") && + COMET_EXEC_SHUFFLE_ENABLED.get(conf) && isCometShuffleManagerEnabled(conf) && // TODO: AQE coalesce partitions feature causes Comet shuffle memory leak. // We should disable Comet shuffle when AQE coalesce partitions is enabled. 
(!conf.coalesceShufflePartitionsEnabled || COMET_SHUFFLE_ENFORCE_MODE_ENABLED.get()) + private[comet] def getCometShuffleNotEnabledReason(conf: SQLConf): Option[String] = { + if (!COMET_EXEC_SHUFFLE_ENABLED.get(conf)) { + Some(s"${COMET_EXEC_SHUFFLE_ENABLED.key} is not enabled") + } else if (!isCometShuffleManagerEnabled(conf)) { + Some(s"spark.shuffle.manager is not set to ${CometShuffleManager.getClass.getName.stripSuffix("$")}") + } else if (conf.coalesceShufflePartitionsEnabled && !COMET_SHUFFLE_ENFORCE_MODE_ENABLED + .get()) { + Some( + s"${SQLConf.COALESCE_PARTITIONS_ENABLED.key} is enabled and " + + s"${COMET_SHUFFLE_ENFORCE_MODE_ENABLED.key} is not enabled") + } else { + None + } + } + + private def isCometShuffleManagerEnabled(conf: SQLConf) = { + conf.contains("spark.shuffle.manager") && conf.getConfString("spark.shuffle.manager") == + "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager" + } + private[comet] def isCometScanEnabled(conf: SQLConf): Boolean = { COMET_SCAN_ENABLED.get(conf) } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index f3fd50e9e..98a2bad02 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1420,14 +1420,16 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "extractintervalmonths is not supported")), ( s"SELECT sum(c0), sum(c2) from $table group by c1", - Set("Native shuffle is not enabled", "AQEShuffleRead is not supported")), + Set( + "Native shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", + "AQEShuffleRead is not supported")), ( "SELECT A.c1, A.sum_c0, A.sum_c2, B.casted from " + s"(SELECT c1, sum(c0) as sum_c0, sum(c2) as sum_c2 from $table group by c1) as A, " + s"(SELECT c1, cast(make_interval(c0, c1, c0, c1, c0, c0, c2) as string) as casted from $table) as B " + "where A.c1 = 
B.c1 ", Set( - "Native shuffle is not enabled", + "Native shuffle is not enabled: spark.comet.exec.shuffle.enabled is not enabled", "AQEShuffleRead is not supported", "make_interval is not supported", "BroadcastExchange is not supported", diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 112d35b13..0530d764c 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -261,7 +261,10 @@ abstract class CometTestBase } val extendedInfo = new ExtendedExplainInfo().generateExtendedInfo(dfComet.queryExecution.executedPlan) - assert(extendedInfo.equalsIgnoreCase(expectedInfo.toSeq.sorted.mkString("\n"))) + val expectedStr = expectedInfo.toSeq.sorted.mkString("\n") + if (!extendedInfo.equalsIgnoreCase(expectedStr)) { + fail(s"$extendedInfo != $expectedStr (case-insensitive comparison)") + } } private var _spark: SparkSession = _