Skip to content

Commit

Permalink
add check for compatibility for decimal to decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
himadripal committed Nov 19, 2024
1 parent 3062de2 commit b8bf29e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ object CometCast {
case _ =>
Unsupported
}
case (_: DecimalType, _: DecimalType) =>
Compatible()
case (from: DecimalType, to: DecimalType) =>
if (to.precision < from.precision)
Incompatible() // datafusion looses precision https://github.com/apache/datafusion/issues/13492
else Compatible()
case (DataTypes.StringType, _) =>
canCastFromString(toType, timeZoneId, evalMode)
case (_, DataTypes.StringType) =>
Expand Down
17 changes: 6 additions & 11 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -896,22 +896,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("cast between decimals with different precision and scale") {
// cast between default Decimal(38, 18) to Decimal(7,2)
// cast between default Decimal(38, 18) to Decimal(6,2)
val values = Seq(BigDecimal("12345.6789"), BigDecimal("9876.5432"), BigDecimal("123.4567"))
val df = withNulls(values).toDF("a")
castTest(df, DataTypes.createDecimalType(7, 2))
}

test("cast between decimals with lower precision and scale") {
// cast between Decimal(10, 2) to Decimal(9,1)
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(9, 1))
val df = withNulls(values)
.toDF("b")
.withColumn("a", col("b").cast(DecimalType(6, 2)))
checkSparkAnswer(df)
}

test("cast between decimals with higher precision than source") {
// cast between Decimal(10, 2) to Decimal(10,4)
withSQLConf("spark.comet.explainFallback.enabled" -> "true") {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
}
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
}

test("cast between decimals with negative precision") {
Expand Down

0 comments on commit b8bf29e

Please sign in to comment.