Skip to content

Commit

Permalink
Support collect_set aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you committed Mar 6, 2024
1 parent c4ae0f8 commit 0e33012
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -480,4 +480,6 @@ object BackendSettings extends BackendSettingsApi {
// vanilla Spark, we need to rewrite the aggregate to get the correct data type.
true
}

// Opts this backend in to the collect_set rewrite (the BackendSettingsApi default is false).
// RewriteCollectSet reads this flag and leaves the plan untouched when it is false.
override def shouldRewriteCollectSet(): Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,46 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}

test("test collect_set") {
  // A single collect_set: the executed plan must contain exactly two
  // HashAggregateExecTransformer nodes.
  runQueryAndCompare("SELECT array_sort(collect_set(l_partkey)) FROM lineitem") {
    df =>
      val transformerCount =
        getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer])
      assert(transformerCount == 2)
  }

  // Two collect_set calls in the same query: still exactly two transformer nodes.
  runQueryAndCompare(
    """
      |SELECT array_sort(collect_set(l_suppkey)), array_sort(collect_set(l_partkey))
      |FROM lineitem
      |""".stripMargin) {
    df =>
      val transformerCount =
        getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer])
      assert(transformerCount == 2)
  }

  // collect_set combined with a distinct aggregate: four transformer nodes are expected.
  runQueryAndCompare(
    "SELECT count(distinct l_suppkey), array_sort(collect_set(l_partkey)) FROM lineitem") {
    df =>
      val transformerCount =
        getExecutedPlan(df).count(_.isInstanceOf[HashAggregateExecTransformer])
      assert(transformerCount == 4)
  }
}

test("count(1)") {
runQueryAndCompare(
"""
Expand Down
10 changes: 1 addition & 9 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,12 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"starts_with", "startswith"},
{"named_struct", "row_constructor"},
{"bit_or", "bitwise_or_agg"},
{"bit_or_partial", "bitwise_or_agg_partial"},
{"bit_or_merge", "bitwise_or_agg_merge"},
{"bit_and", "bitwise_and_agg"},
{"bit_and_partial", "bitwise_and_agg_partial"},
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"collect_set_partial", "set_agg_partial"},
{"collect_set_merge", "set_agg_merge"},
{"collect_list", "array_agg"},
{"collect_list_partial", "array_agg_partial"},
{"collect_list_merge", "array_agg_merge"}};
{"collect_list", "array_agg"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,6 @@ trait BackendSettingsApi {
// Backend switch, off by default.
// NOTE(review): presumably controls merging of two-phase hash aggregates — confirm at call sites.
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

// Backend switch, off by default.
// NOTE(review): presumably gates the RewriteTypedImperativeAggregate rule — confirm at call sites.
def shouldRewriteTypedImperativeAggregate(): Boolean = false

// Whether RewriteCollectSet should add an `IsNotNull` filter to collect_set aggregates.
// Off by default; a backend opts in by overriding this to return true.
def shouldRewriteCollectSet(): Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ object ColumnarOverrideRules {
// Plan rewrite rules; RewriteCollectSet is listed ahead of RewriteTypedImperativeAggregate.
// NOTE(review): presumably applied in the listed order — confirm where rewriteRules is consumed
// before reordering entries.
val rewriteRules = Seq(
RewriteIn,
RewriteMultiChildrenCount,
RewriteCollectSet,
RewriteTypedImperativeAggregate,
PullOutPreProject,
PullOutPostProject)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.extension

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet, Complete, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

/**
* This rule adds an `IsNotNull` filter to `collect_set` aggregate expressions so that null
* values are skipped before they reach the native collect_set.
*
* TODO: remove this rule once Velox collect_set skips null values.
*/
object RewriteCollectSet extends Rule[SparkPlan] with PullOutProjectHelper {
  // Backend switch for this rule, resolved once on first use.
  private lazy val shouldRewriteCollect =
    BackendsApiManager.getSettings.shouldRewriteCollectSet()

  /** True for a `CollectSet` aggregate expression whose mode is `Partial` or `Complete`. */
  private def shouldRewrite(ae: AggregateExpression): Boolean =
    (ae.aggregateFunction, ae.mode) match {
      case (_: CollectSet, Partial | Complete) => true
      case _ => false
    }

  /**
   * Returns `aggExprs` where every expression accepted by `shouldRewrite` carries an additional
   * `IsNotNull` predicate on the first child of its aggregate function. A filter that is already
   * present is kept and AND-ed in front of the new predicate; all other expressions are passed
   * through unchanged.
   */
  private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]): Seq[AggregateExpression] =
    aggExprs.map {
      case target if shouldRewrite(target) =>
        val notNull = IsNotNull(target.aggregateFunction.children.head)
        // Existing filter (if any) first, then the null check.
        val combined = (target.filter.toSeq :+ notNull).reduce(And)
        target.copy(filter = Some(combined))
      case untouched => untouched
    }

  /**
   * Rewrites `plan` when the backend enables this rule and `plan` is a `BaseAggregateExec` with at
   * least one qualifying collect_set expression; otherwise returns `plan` as is.
   */
  override def apply(plan: SparkPlan): SparkPlan =
    if (shouldRewriteCollect) {
      plan match {
        case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) =>
          copyBaseAggregateExec(agg)(
            newAggregateExpressions = rewriteCollectFilter(agg.aggregateExpressions))

        case _ => plan
      }
    } else {
      plan
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it",
// TODO: fix inconsistent behavior.
"SPARK-17641: collect functions should not collect null values"
" before using it"
)

enableSuite[GlutenCastSuite]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,7 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it",
// TODO: fix inconsistent behavior.
"SPARK-17641: collect functions should not collect null values"
" before using it"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,9 +938,7 @@ class VeloxTestSettings extends BackendTestSettings {
"SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate",
// Replaced with another test.
"SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it",
// TODO: fix inconsistent behavior.
"SPARK-17641: collect functions should not collect null values"
" before using it"
)
enableSuite[GlutenDataFrameAsOfJoinSuite]
enableSuite[GlutenDataFrameComplexTypeSuite]
Expand Down

0 comments on commit 0e33012

Please sign in to comment.