fix: Fall back to Spark when hashing decimals with precision > 18 (#1325)

* fall back to Spark when hashing decimals with precision > 18

* murmur3 checks

* refactor

* fix

* address feedback
andygrove authored Jan 29, 2025
1 parent 07274e8 commit e964947
Showing 3 changed files with 127 additions and 42 deletions.
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala (3 additions & 29 deletions)

@@ -2176,35 +2176,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       None
     }

-      case Murmur3Hash(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(IntegerType).get)
-          .setIntVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
-
-      case XxHash64(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(LongType).get)
-          .setLongVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+      case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)
+
+      case _: XxHash64 => CometXxHash64.convert(expr, inputs, binding)

       case Sha2(left, numBits) =>
         if (!numBits.foldable) {
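The interesting part of this hunk is the shape of the refactor: the inline serialization logic moves out of the big `match` in `QueryPlanSerde` and behind one small serde object per expression. A self-contained sketch of that pattern with simplified placeholder types (not Comet's real API; the actual `CometExpressionSerde` objects appear in `hash.scala` below):

```scala
// Simplified stand-ins for Spark expressions; names are illustrative only.
sealed trait Expr
final case class Murmur3Hash(children: Seq[String], seed: Int) extends Expr
final case class XxHash64(children: Seq[String], seed: Long) extends Expr

// One handler per expression kind; None means "fall back to Spark".
trait ExpressionSerde {
  def convert(expr: Expr): Option[String]
}

object Murmur3HashSerde extends ExpressionSerde {
  def convert(expr: Expr): Option[String] = expr match {
    // The seed is appended as the last argument, mirroring the real serde.
    case Murmur3Hash(children, seed) =>
      Some(s"murmur3_hash(${(children :+ seed.toString).mkString(", ")})")
    case _ => None
  }
}

object SerdeDemo extends App {
  // The central dispatcher shrinks to one delegation line per case.
  def toProto(expr: Expr): Option[String] = expr match {
    case _: Murmur3Hash => Murmur3HashSerde.convert(expr)
    case _ => None
  }
  println(toProto(Murmur3Hash(Seq("c1", "c2"), seed = 42)))
  // prints: Some(murmur3_hash(c1, c2, 42))
}
```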
spark/src/main/scala/org/apache/comet/serde/hash.scala (new file, 85 additions)

@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, XxHash64}
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarExprToProtoWithReturnType, serializeDataType, supportedDataType}

object CometXxHash64 extends CometExpressionSerde {
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    if (!HashUtils.isSupportedType(expr)) {
      return None
    }
    val hash = expr.asInstanceOf[XxHash64]
    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
    val seedBuilder = ExprOuterClass.Literal
      .newBuilder()
      .setDatatype(serializeDataType(LongType).get)
      .setLongVal(hash.seed)
    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
    // the seed is put at the end of the arguments
    scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
  }
}

object CometMurmur3Hash extends CometExpressionSerde {
  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    if (!HashUtils.isSupportedType(expr)) {
      return None
    }
    val hash = expr.asInstanceOf[Murmur3Hash]
    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
    val seedBuilder = ExprOuterClass.Literal
      .newBuilder()
      .setDatatype(serializeDataType(IntegerType).get)
      .setIntVal(hash.seed)
    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
    // the seed is put at the end of the arguments
    scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
  }
}

private object HashUtils {
  def isSupportedType(expr: Expression): Boolean = {
    for (child <- expr.children) {
      child.dataType match {
        case dt: DecimalType if dt.precision > 18 =>
          // Spark converts decimals with precision > 18 into
          // Java BigDecimal before hashing
          withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
          return false
        case dt if !supportedDataType(dt) =>
          withInfo(expr, s"Unsupported datatype $dt")
          return false
        case _ =>
      }
    }
    true
  }
}
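For context on the precision-18 cutoff in `HashUtils`: Spark keeps decimals with precision <= 18 in a single unscaled `Long` and hashes that directly, while wider values are converted to `java.math.BigDecimal` and hashed through their unscaled bytes, which Comet's native hash does not yet replicate. A rough, self-contained illustration of the two paths, with placeholder hash functions standing in for Spark's actual Murmur3/xxHash64 implementations:

```scala
import java.math.BigDecimal

object DecimalHashSketch extends App {
  // Placeholder hashes; Spark really uses Murmur3HashFunction / XxHash64Function.
  def hashLong(v: Long): Int = java.lang.Long.hashCode(v)
  def hashBytes(b: Array[Byte]): Int = java.util.Arrays.hashCode(b)

  // Rough sketch of Spark's branching: small decimals hash their unscaled
  // Long, wide decimals hash the unscaled value's byte representation.
  def hashDecimal(d: BigDecimal, precision: Int): Int =
    if (precision <= 18) hashLong(d.unscaledValue().longValueExact())
    else hashBytes(d.unscaledValue().toByteArray)

  println(hashDecimal(new BigDecimal("1.23"), precision = 18)) // Long path
  println(hashDecimal(new BigDecimal("1.23"), precision = 20)) // byte path
}
```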
spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala (39 additions & 13 deletions)

@@ -1929,19 +1929,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("hash functions with decimal input") {
-    withTable("t1", "t2") {
-      // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it.
-      // Else, turn it into bytes and hash it.
-      sql("create table t1(c1 decimal(18, 2)) using parquet")
-      sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
-      checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t1 order by c1")
-
-      // TODO: comet hash function is not compatible with spark for decimal with precision greater than 18.
-      // https://github.com/apache/datafusion-comet/issues/1294
-      // sql("create table t2(c1 decimal(20, 2)) using parquet")
-      // sql("insert into t2 values(1.23), (-1.23), (0.0), (null)")
-      // checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t2 order by c1")
+  test("hash function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, hash(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, hash(c1) from t1 order by c1")
+        }
+      }
+    }
+  }
+
+  test("xxhash64 function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, xxhash64(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, xxhash64(c1) from t1 order by c1")
+        }
+      }
     }
   }
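To see the new behavior end to end, a hedged sketch against a Spark session with Comet enabled (the `spark` session, table name, and `decimal(20, 2)` column here are illustrative, not part of the commit): wide-decimal hashes now return Spark's answers via fallback instead of running natively, until the native gap tracked in https://github.com/apache/datafusion-comet/issues/1294 is closed.

```scala
// Assumes `spark` is a SparkSession with the Comet extension enabled.
spark.sql("create table wide_dec(c1 decimal(20, 2)) using parquet")
spark.sql("insert into wide_dec values (1.23), (-1.23), (0.0), (null)")
// hash()/xxhash64() fall back to Spark for precision > 18, so results
// match vanilla Spark; precision <= 18 still runs natively in Comet.
spark.sql("select c1, hash(c1), xxhash64(c1) from wide_dec order by c1").show()
```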

