fix: CometExecRule should handle ShuffleQueryStage and ReusedExchange #186

Merged · 5 commits · Mar 13, 2024
@@ -31,12 +31,13 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -221,12 +222,16 @@ class CometSparkSessionExtensions
*/
// spotless:on
private def transform(plan: SparkPlan): SparkPlan = {
def transform1(op: UnaryExecNode): Option[Operator] = {
op.child match {
case childNativeOp: CometNativeExec =>
QueryPlanSerde.operator2Proto(op, childNativeOp.nativeOp)
case _ =>
None
def transform1(op: SparkPlan): Option[Operator] = {
val allNativeExec = op.children.map {
Contributor: I don't see why this change is needed here.

Member Author: They are not UnaryExecNode.

case childNativeOp: CometNativeExec => Some(childNativeOp.nativeOp)
case _ => None
}

if (allNativeExec.forall(_.isDefined)) {
QueryPlanSerde.operator2Proto(op, allNativeExec.map(_.get): _*)
} else {
None
}
}
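
For readers skimming the exchange above: in Spark's adaptive execution, ShuffleQueryStageExec is a leaf wrapper around an already-materialized exchange rather than a UnaryExecNode, which is why transform1 now accepts any SparkPlan and folds over op.children. A minimal sketch of the distinction (illustration only, not part of this PR; describeNode is a hypothetical helper):

    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
    import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec

    // Sketch only: classify a plan node the way the generalized transform1 must.
    def describeNode(plan: SparkPlan): String = plan match {
      case s: ShuffleQueryStageExec =>
        // A query stage is a leaf node; the shuffle it wraps is reachable via `s.plan`,
        // not via `children`, so a UnaryExecNode-only signature cannot match it.
        s"leaf query stage wrapping ${s.plan.nodeName}"
      case u: UnaryExecNode =>
        s"unary node with child ${u.child.nodeName}"
      case other =>
        s"node with ${other.children.size} children"
    }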

@@ -377,6 +382,31 @@ class CometSparkSessionExtensions
case None => b
}

// For AQE shuffle stage on a Comet shuffle exchange
case s @ ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

// For AQE shuffle stage on a reused Comet shuffle exchange
// Note that we don't need to handle `ReusedExchangeExec` for non-AQE case, because
// the query plan won't be re-optimized/planned in non-AQE mode.
case s @ ShuffleQueryStageExec(
Member: Curious whether we can replace this rule with:

        case ReusedExchangeExec(_, op) => op

Member Author: Hmm, then users won't see ReusedExchangeExec anymore in the explain string or Spark UI. It is useful to know which part is reused by Spark.

Member: Oops, you are right.

_,
ReusedExchangeExec(_, _: CometShuffleExchangeExec),
_) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

// Native shuffle for Comet operators
case s: ShuffleExchangeExec
if isCometShuffleEnabled(conf) &&
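The simplification floated in the review thread above — stripping ReusedExchangeExec and transforming the referenced exchange directly — would look roughly like the sketch below (hypothetical, not what this PR does; unwrapReusedExchanges is an illustrative helper). It was rejected because the reuse marker would then disappear from the explain string and the Spark UI.

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

    // Hypothetical simplification discussed above: strip every ReusedExchangeExec
    // and hand its referenced exchange to the transform directly. The trade-off is
    // that the reuse marker no longer appears in EXPLAIN output or the Spark UI.
    def unwrapReusedExchanges(plan: SparkPlan): SparkPlan = plan.transformUp {
      case ReusedExchangeExec(_, reused) => reused
    }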
@@ -29,10 +29,12 @@ import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1883,6 +1885,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case _: CollectLimitExec => true
case _: UnionExec => true
case _: ShuffleExchangeExec => true
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case _: BroadcastExchangeExec => true
case _ => false
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkConf}
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.functions.col
@@ -933,6 +933,52 @@ class CometShuffleSuite extends CometColumnarShuffleSuite {
override protected val asyncShuffleEnable: Boolean = false

protected val adaptiveExecutionEnabled: Boolean = true

import testImplicits._

test("Comet native operator after ShuffleQueryStage") {
withSQLConf(
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
val df = sql("SELECT * FROM tbl_a")
val shuffled = df
.select($"_1" + 1 as ("a"))
.filter($"a" > 4)
.repartition(10)
.sortWithinPartitions($"a")
checkSparkAnswerAndOperator(shuffled, classOf[ShuffleQueryStageExec])
}
}
}

test("Comet native operator after ShuffleQueryStage + ReusedExchange") {
withSQLConf(
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
CometConf.COMET_EXEC_ENABLED.key -> "true",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
withParquetTable((0 until 10).map(i => (i, i % 5)), "tbl_a") {
withParquetTable((0 until 10).map(i => (i % 10, i + 2)), "tbl_b") {
val df = sql("SELECT * FROM tbl_a")
val left = df
.select($"_1" + 1 as ("a"))
.filter($"a" > 4)
val right = left.select($"a" as ("b"))
val join = left.join(right, $"a" === $"b")
checkSparkAnswerAndOperator(
join,
classOf[ShuffleQueryStageExec],
classOf[SortMergeJoinExec],
classOf[AQEShuffleReadExec])
}
}
}
}
}

class DisableAQECometShuffleSuite extends CometColumnarShuffleSuite {
@@ -64,8 +64,9 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 1000)
var allTypes: Seq[Int] = (1 to 20)
if (isSpark34Plus) {
allTypes = allTypes.filterNot(Set(14, 17).contains)
if (!isSpark34Plus) {
// TODO: Remove this once https://github.com/apache/arrow/issues/40038 is fixed
allTypes = allTypes.filterNot(Set(14).contains)
}
Member Author: Previously these columns failed on Spark 3.4. I just ran the test with them and it looks okay now.

Member Author (on lines +67 to 70): Besides, after this change a Comet native operator can come after CometExchange, which triggers the known bug in Java Arrow on column _14. I exclude that column for Spark 3.2 and 3.3.

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=true
+- == Final Plan ==
   *(1) ColumnarToRow
   +- CometProject [_1#556], [_1#556]
      +- ShuffleQueryStage 0
         +- CometExchange hashpartitioning(_13#568, 10), REPARTITION_BY_NUM, CometNativeShuffle, [plan_id=838]
            +- CometScan parquet [_1#556,_13#568] Batched: true, DataFilters: [], Format: CometParquet, Location: InMemoryFileIndex(1 paths)[file:/Users/liangchi/repos/arrow-datafusion-comet/spark/target/tmp/spa..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<_1:boolean,_13:string>
+- == Initial Plan ==
   CometProject [_1#556], [_1#556]
   +- CometExchange hashpartitioning(_13#568, 10), REPARTITION_BY_NUM, CometNativeShuffle, [plan_id=830]
      +- CometScan parquet [_1#556,_13#568] Batched: true, DataFilters: [], Format: CometParquet, Location: InMemoryFileIndex(1 paths)[file:/Users/liangchi/repos/arrow-datafusion-comet/spark/target/tmp/spa..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<_1:boolean,_13:string>

allTypes.map(i => s"_$i").foreach { c =>
withSQLConf(