[GLUTEN-7359][VL] Optimize string in partial project (#7592)
jinchengchenghh authored Oct 28, 2024
1 parent f5d42f1 commit 932d8a2
Showing 7 changed files with 626 additions and 71 deletions.
@@ -18,27 +18,26 @@ package org.apache.gluten.execution
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
-import org.apache.gluten.expression.ExpressionUtils
+import org.apache.gluten.expression.{ArrowProjection, ExpressionUtils}
 import org.apache.gluten.extension.{GlutenPlan, ValidationResult}
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.ArrowWritableColumnVector
+import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, MutableProjection, NamedExpression, NaNvl, ScalaUDF, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, NamedExpression, NaNvl, ScalaUDF}
 import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, WritableColumnVector}
 import org.apache.spark.sql.hive.HiveUdfUtil
 import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, NullType, ShortType, StringType, TimestampType, YearMonthIntervalType}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 import scala.collection.mutable.ListBuffer
 
 /**
- * By rule <PartialProhectRule>, the project not offload-able that is changed to
+ * By rule <PartialProjectRule>, the project not offload-able that is changed to
  * ProjectExecTransformer + ColumnarPartialProjectExec e.g. sum(myudf(a) + b + hash(c)), child is
  * (a, b, c) ColumnarPartialProjectExec (a, b, c, myudf(a) as _SparkPartialProject1),
  * ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
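Note: to make the example in this doc comment concrete, here is an illustrative sketch of the plan rewrite it describes (plan shapes only, not code from this commit):

    // For sum(myudf(a) + b + hash(c)) over a child producing (a, b, c):
    //
    // Before: the whole project falls back to the JVM
    //   ProjectExec(sum(myudf(a) + b + hash(c)))
    //
    // After: only the UDF is evaluated on the JVM, the rest stays in Velox
    //   ProjectExecTransformer(_SparkPartialProject1 + b + hash(c))
    //     +- ColumnarPartialProjectExec(a, b, c, myudf(a) AS _SparkPartialProject1)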
@@ -64,12 +63,12 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
 
   @transient override lazy val metrics = Map(
     "time" -> SQLMetrics.createTimingMetric(sparkContext, "total time of partial project"),
-    "column_to_row_time" -> SQLMetrics.createTimingMetric(
+    "velox_to_arrow_time" -> SQLMetrics.createTimingMetric(
       sparkContext,
-      "time of velox to Arrow ColumnarBatch or UnsafeRow"),
-    "row_to_column_time" -> SQLMetrics.createTimingMetric(
+      "time of velox to Arrow ColumnarBatch"),
+    "arrow_to_velox_time" -> SQLMetrics.createTimingMetric(
       sparkContext,
-      "time of Arrow ColumnarBatch or UnsafeRow to velox")
+      "time of Arrow ColumnarBatch to velox")
   )
 
   override def output: Seq[Attribute] = child.output ++ replacedAliasUdf.map(_.toAttribute)
@@ -111,22 +110,26 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
   }
 
   private def getProjectIndexInChildOutput(exprs: Seq[Expression]): Unit = {
-    exprs.foreach {
+    exprs.forall {
       case a: AttributeReference =>
         val index = child.output.indexWhere(s => s.exprId.equals(a.exprId))
         // Some child operator as HashAggregateTransformer will not have udf child column
         if (index < 0) {
           UDFAttrNotExists = true
           log.debug(s"Expression $a should exist in child output ${child.output}")
-          return
+          false
         } else if (!validateDataType(a.dataType)) {
           hasUnsupportedDataType = true
           log.debug(s"Expression $a contains unsupported data type ${a.dataType}")
+          false
         } else if (!projectIndexInChild.contains(index)) {
           projectAttributes.append(a.toAttribute)
           projectIndexInChild.append(index)
-        }
-      case p => getProjectIndexInChildOutput(p.children)
+          true
+        } else true
+      case p =>
+        getProjectIndexInChildOutput(p.children)
+        true
     }
   }

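Note on the control-flow change above: foreach needed a non-local return to bail out early, while forall simply stops at the first case that yields false. A minimal standalone sketch of the pattern (hypothetical names, not Gluten code):

    import scala.collection.mutable.ListBuffer

    // forall short-circuits on the first false, so scanning stops at the first
    // unsupported element; side effects performed before that point are kept.
    def collectPositives(xs: Seq[Int], out: ListBuffer[Int]): Unit = {
      xs.forall {
        case x if x < 0 => false // unsupported: skip the remaining elements
        case x =>
          if (!out.contains(x)) out += x // record it, as projectIndexInChild does
          true
      }
    }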
@@ -150,7 +153,7 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
       return ValidationResult.failed("No UDF")
     }
     if (replacedAliasUdf.size > original.output.size) {
-      // e.g. udf1(col) + udf2(col), it will introduce 2 cols for r2c
+      // e.g. udf1(col) + udf2(col), it will introduce 2 cols for a2c
      return ValidationResult.failed("Number of RowToColumn columns is more than ProjectExec")
     }
     if (!original.projectList.forall(validateExpression(_))) {
@@ -168,9 +171,8 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
 
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     val totalTime = longMetric("time")
-    val c2r = longMetric("column_to_row_time")
-    val r2c = longMetric("row_to_column_time")
-    val isMutable = canUseMutableProjection()
+    val c2a = longMetric("velox_to_arrow_time")
+    val a2c = longMetric("arrow_to_velox_time")
     child.executeColumnar().mapPartitions {
       batches =>
         val res: Iterator[Iterator[ColumnarBatch]] = new Iterator[Iterator[ColumnarBatch]] {
@@ -183,9 +185,8 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
           } else {
             val start = System.currentTimeMillis()
             val childData = ColumnarBatches.select(batch, projectIndexInChild.toArray)
-            val projectedBatch = if (isMutable) {
-              getProjectedBatchArrow(childData, c2r, r2c)
-            } else getProjectedBatch(childData, c2r, r2c)
+            val projectedBatch = getProjectedBatchArrow(childData, c2a, a2c)
+
             val batchIterator = projectedBatch.map {
               b =>
                 if (b.numCols() != 0) {
@@ -214,60 +215,12 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
    }
  }
 
-  // scalastyle:off line.size.limit
-  // String type cannot use MutableProjection
-  // Otherwise will throw java.lang.UnsupportedOperationException: Datatype not supported StringType
-  // at org.apache.spark.sql.execution.vectorized.MutableColumnarRow.update(MutableColumnarRow.java:224)
-  // at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificMutableProjection.apply(Unknown Source)
-  // scalastyle:on line.size.limit
-  private def canUseMutableProjection(): Boolean = {
-    replacedAliasUdf.forall(
-      r =>
-        r.dataType match {
-          case StringType | BinaryType => false
-          case _ => true
-        })
-  }
-
-  /**
-   * add c2r and r2c for unsupported expression child data c2r get Iterator[InternalRow], then call
-   * Spark project, then r2c
-   */
-  private def getProjectedBatch(
-      childData: ColumnarBatch,
-      c2r: SQLMetric,
-      r2c: SQLMetric): Iterator[ColumnarBatch] = {
-    // select part of child output and child data
-    val proj = UnsafeProjection.create(replacedAliasUdf, projectAttributes.toSeq)
-    val numOutputRows = new SQLMetric("numOutputRows")
-    val numInputBatches = new SQLMetric("numInputBatches")
-    val rows = VeloxColumnarToRowExec
-      .toRowIterator(
-        Iterator.single[ColumnarBatch](childData),
-        projectAttributes.toSeq,
-        numOutputRows,
-        numInputBatches,
-        c2r)
-      .map(proj)
-
-    val schema =
-      SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
-    RowToVeloxColumnarExec.toColumnarBatchIterator(
-      rows,
-      schema,
-      numOutputRows,
-      numInputBatches,
-      r2c,
-      childData.numRows())
-    // TODO: should check the size <= 1, but now it has bug, will change iterator to empty
-  }
-
   private def getProjectedBatchArrow(
       childData: ColumnarBatch,
       c2a: SQLMetric,
       a2c: SQLMetric): Iterator[ColumnarBatch] = {
     // select part of child output and child data
-    val proj = MutableProjection.create(replacedAliasUdf, projectAttributes.toSeq)
+    val proj = ArrowProjection.create(replacedAliasUdf, projectAttributes.toSeq)
     val numRows = childData.numRows()
     val start = System.currentTimeMillis()
     val arrowBatch = if (childData.numCols() == 0) {
@@ -279,14 +232,14 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(
 
     val schema =
       SparkShimLoader.getSparkShims.structFromAttributes(replacedAliasUdf.map(_.toAttribute))
-    val vectors: Array[WritableColumnVector] = ArrowWritableColumnVector
+    val vectors: Array[ArrowWritableColumnVector] = ArrowWritableColumnVector
       .allocateColumns(numRows, schema)
       .map {
         vector =>
           vector.setValueCount(numRows)
           vector
       }
-    val targetRow = new MutableColumnarRow(vectors)
+    val targetRow = new ArrowColumnarRow(vectors)
     for (i <- 0 until numRows) {
       targetRow.rowId = i
       proj.target(targetRow).apply(arrowBatch.getRow(i))
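Net effect of the hunks above: both the MutableProjection fast path and the row-based fallback (Velox to row, UnsafeProjection, row back to Velox) are deleted. The removed comment explains why the split existed: MutableColumnarRow.update throws UnsupportedOperationException for StringType, so String/Binary UDF results had to take the UnsafeRow round trip. The new ArrowColumnarRow target removes that limitation, leaving a single path. A condensed sketch of it (simplified from the diff; numRows, schema and arrowBatch are as in getProjectedBatchArrow):

    // Single surviving path: project each input row directly into Arrow vectors.
    val proj = ArrowProjection.create(replacedAliasUdf, projectAttributes.toSeq)
    val vectors: Array[ArrowWritableColumnVector] =
      ArrowWritableColumnVector.allocateColumns(numRows, schema)
    vectors.foreach(_.setValueCount(numRows))
    val targetRow = new ArrowColumnarRow(vectors)
    for (i <- 0 until numRows) {
      targetRow.rowId = i
      // evaluate the replaced UDF expressions for row i; unlike
      // MutableColumnarRow, this target can also write String/Binary results
      proj.target(targetRow).apply(arrowBatch.getRow(i))
    }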
@@ -70,6 +70,8 @@ abstract class UDFPartialProjectSuite extends WholeStageTransformerSuite
     spark.udf.register("plus_one", plusOne)
     val noArgument = udf(() => 15)
     spark.udf.register("no_argument", noArgument)
+    val concat = udf((x: String) => x + "_concat")
+    spark.udf.register("concat_concat", concat)
 
   }

@@ -139,4 +141,10 @@ abstract class UDFPartialProjectSuite extends WholeStageTransformerSuite {
     }
   }
 
+  test("test concat with string") {
+    runQueryAndCompare("SELECT concat_concat(l_comment), hash(l_partkey) from lineitem") {
+      checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+    }
+  }
+
 }
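For context, the new test corresponds to usage like the following (hypothetical spark-shell session, assuming a Gluten-enabled Spark and a TPC-H lineitem table):

    // A String-returning UDF should now be handled by ColumnarPartialProjectExec
    // instead of blocking offload of the surrounding project.
    spark.udf.register("concat_concat", (x: String) => x + "_concat")
    val df = spark.sql("SELECT concat_concat(l_comment), hash(l_partkey) FROM lineitem")
    df.explain() // expect ColumnarPartialProjectExec in the printed plan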