Skip to content

Commit

Permalink
Support collect_set
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you committed Feb 29, 2024
1 parent 908b2f6 commit 98ab12b
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -449,4 +449,6 @@ object BackendSettings extends BackendSettingsApi {
// vanilla Spark, we need to rewrite the aggregate to get the correct data type.
true
}

override def shouldRewriteCollectSet(): Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,46 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu
}
}

test("test collect_set") {
runQueryAndCompare("SELECT array_sort(collect_set(l_partkey)) FROM lineitem") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

runQueryAndCompare(
"""
|SELECT array_sort(collect_set(l_suppkey)), array_sort(collect_set(l_partkey))
|FROM lineitem
|""".stripMargin) {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 2)
}
}

runQueryAndCompare(
"SELECT count(distinct l_suppkey), array_sort(collect_set(l_partkey)) FROM lineitem") {
df =>
{
assert(
getExecutedPlan(df).count(
plan => {
plan.isInstanceOf[HashAggregateExecTransformer]
}) == 4)
}
}
}

test("count(1)") {
runQueryAndCompare(
"""
Expand Down
10 changes: 1 addition & 9 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,20 +368,12 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
{"starts_with", "startswith"},
{"named_struct", "row_constructor"},
{"bit_or", "bitwise_or_agg"},
{"bit_or_partial", "bitwise_or_agg_partial"},
{"bit_or_merge", "bitwise_or_agg_merge"},
{"bit_and", "bitwise_and_agg"},
{"bit_and_partial", "bitwise_and_agg_partial"},
{"bit_and_merge", "bitwise_and_agg_merge"},
{"murmur3hash", "hash_with_seed"},
{"modulus", "remainder"},
{"date_format", "format_datetime"},
{"collect_set", "set_agg"},
{"collect_set_partial", "set_agg_partial"},
{"collect_set_merge", "set_agg_merge"},
{"collect_list", "array_agg"},
{"collect_list_partial", "array_agg_partial"},
{"collect_list_merge", "array_agg_merge"}};
{"collect_list", "array_agg"}};

const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
Expand Down
4 changes: 2 additions & 2 deletions ep/build-velox/src/get_velox.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

set -exu

VELOX_REPO=https://github.com/oap-project/velox.git
VELOX_BRANCH=2024_02_28
VELOX_REPO=https://github.com/ulysses-you/velox.git
VELOX_BRANCH=setagg
VELOX_HOME=""

#Set on run gluten on HDFS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,6 @@ trait BackendSettingsApi {
def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = false

def shouldRewriteTypedImperativeAggregate(): Boolean = false

def shouldRewriteCollectSet(): Boolean = false
}
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ object ColumnarOverrideRules {
def rewriteSparkPlanRule(): Rule[SparkPlan] = {
val rewriteRules = Seq(
RewriteMultiChildrenCount,
RewriteCollectSet,
RewriteTypedImperativeAggregate,
PullOutPreProject,
PullOutPostProject)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.extension

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectSet, Complete, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec

/**
* This rule Add `IsNotNull` to skip null value before going to native collect_set
*
* TODO: remove this rule once Velox collect_set skip null value
*/
object RewriteCollectSet extends Rule[SparkPlan] with PullOutProjectHelper {
private lazy val shouldRewriteCollect =
BackendsApiManager.getSettings.shouldRewriteCollectSet()

private def shouldRewrite(ae: AggregateExpression): Boolean = {
ae.aggregateFunction match {
case _: CollectSet =>
ae.mode match {
case Partial | Complete => true
case _ => false
}
case _ => false
}
}

private def rewriteCollectFilter(aggExprs: Seq[AggregateExpression]): Seq[AggregateExpression] = {
aggExprs
.map {
aggExpr =>
if (shouldRewrite(aggExpr)) {
val newFilter =
(aggExpr.filter ++ Seq(IsNotNull(aggExpr.aggregateFunction.children.head)))
.reduce(And)
aggExpr.copy(filter = Option(newFilter))
} else {
aggExpr
}
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (!shouldRewriteCollect) {
return plan
}

plan match {
case agg: BaseAggregateExec if agg.aggregateExpressions.exists(shouldRewrite) =>
val newAggExprs = rewriteCollectFilter(agg.aggregateExpressions)
copyBaseAggregateExec(agg)(newAggregateExpressions = newAggExprs)

case _ => plan
}
}
}

0 comments on commit 98ab12b

Please sign in to comment.