diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 9e2ca987f..f36b41aa6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -258,48 +258,44 @@ abstract class CometNativeExec extends CometExec {
         // If the first non broadcast plan is found, we need to adjust the partition number of
         // the broadcast plans to make sure they have the same partition number as the first non
         // broadcast plan.
-        val firstNonBroadcastPlanNumPartitions =
-          firstNonBroadcastPlan.map(_._1.outputPartitioning.numPartitions)
+        val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) =
+          firstNonBroadcastPlan.get._1 match {
+            case plan: CometNativeExec =>
+              (null, plan.outputPartitioning.numPartitions)
+            case plan =>
+              val rdd = plan.executeColumnar()
+              (rdd, rdd.getNumPartitions)
+          }
 
         // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with
         // same partition number. But for Comet, we need to zip them so we need to adjust the
         // partition number of Broadcast RDDs to make sure they have the same partition number.
-        sparkPlans.zipWithIndex.foreach { case (plan, _) =>
+        sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
           plan match {
-            case c: CometBroadcastExchangeExec if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
-            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _)
-                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
-            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec)
-                if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
+            case c: CometBroadcastExchangeExec =>
+              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
+              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
+            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
+              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case BroadcastQueryStageExec(
                   _,
                   ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
-                  _) if firstNonBroadcastPlanNumPartitions.nonEmpty =>
-              inputs += c
-                .setNumPartitions(firstNonBroadcastPlanNumPartitions.get)
-                .executeColumnar()
+                  _) =>
+              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
             case _: CometNativeExec =>
            // no-op
-            case _ if firstNonBroadcastPlanNumPartitions.nonEmpty =>
+            case _ if idx == firstNonBroadcastPlan.get._2 =>
+              inputs += firstNonBroadcastPlanRDD
+            case _ =>
               val rdd = plan.executeColumnar()
-              if (plan.outputPartitioning.numPartitions != firstNonBroadcastPlanNumPartitions.get) {
+              if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) {
                 throw new CometRuntimeException(
                   s"Partition number mismatch: ${rdd.getNumPartitions} != " +
-                    s"${firstNonBroadcastPlanNumPartitions.get}")
+                    s"$firstNonBroadcastPlanNumPartitions")
               } else {
                 inputs += rdd
               }
-            case _ =>
-              throw new CometRuntimeException(s"Unexpected plan: $plan")
           }
         }
 
@@ -310,7 +306,7 @@ abstract class CometNativeExec extends CometExec {
         if (inputs.nonEmpty) {
           ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
         } else {
-          val partitionNum = firstNonBroadcastPlanNumPartitions.get
+          val partitionNum = firstNonBroadcastPlanNumPartitions
          CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
        }
    }
@@ -648,6 +644,7 @@ case class CometUnionExec(
     override val output: Seq[Attribute],
     children: Seq[SparkPlan])
     extends CometExec {
+
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     sparkContext.union(children.map(_.executeColumnar()))
   }
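
Note (not part of the patch): the partition-number check above exists because Comet zips the
child RDDs partition-by-partition via ZippedPartitionsRDD, and zipping only works when every
input has the same number of partitions; that is why broadcast inputs are re-partitioned with
setNumPartitions to match the first non-broadcast plan. A minimal sketch of the same constraint
using plain Spark RDDs follows (class name and values are illustrative only):

import org.apache.spark.sql.SparkSession

object ZipPartitionCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("zip-sketch").getOrCreate()
    val sc = spark.sparkContext

    val a = sc.parallelize(1 to 8, numSlices = 4) // 4 partitions
    val b = sc.parallelize(1 to 8, numSlices = 2) // 2 partitions

    try {
      // Spark refuses to zip RDDs whose partition counts differ; this is the same
      // condition the CometRuntimeException in the patch guards against.
      a.zipPartitions(b) { (left, right) => left.zip(right) }.collect()
    } catch {
      case e: IllegalArgumentException =>
        println(s"zip rejected: ${e.getMessage}")
    }

    spark.stop()
  }
}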