diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 56b66909e..b27fa3a75 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1783,6 +1783,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { if (aggregateExpressions.isEmpty) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + emitWarning(s"Unsupported result expressions found in: ${resultExpressions}") + return None + } + hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) Some(result.setHashAgg(hashAggBuilder).build()) } else { val modes = aggregateExpressions.map(_.mode).distinct diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index bc645cb6a..8a68a925e 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -40,6 +40,22 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { import testImplicits._ + test("Aggregation without aggregate expressions should use correct result expressions") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") { + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test") + makeParquetFile(path, 10000, 10, false) + withParquetTable(path.toUri.toString, "tbl") { + val df = sql("SELECT _g5 FROM tbl GROUP BY _g1, _g2, _g3, _g4, _g5") + checkSparkAnswer(df) + } + } + } + } + test("Final aggregation should not bind to the input of partial aggregation") { withSQLConf( CometConf.COMET_ENABLED.key -> "true",