[SPARK-48498][SQL] Always do char padding in predicates
### What changes were proposed in this pull request?

For some data sources, CHAR type padding is not applied on either the write side or the read side (by disabling `spark.sql.readSideCharPadding`), in order to follow a different SQL flavor similar to MySQL: https://dev.mysql.com/doc/refman/8.0/en/char.html

However, there is a bug in Spark: when comparing a CHAR type column with a STRING literal, Spark always pads the string literal, which assumes the CHAR type column is already padded on either the write side or the read side. This is not always true.
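
For illustration, a minimal repro sketch of the buggy case (the table name, location, and values are assumptions, mirroring the new test added below):

```scala
import spark.implicits._

// Write raw, unpadded values straight to files, bypassing write-side padding,
// and read them back with read-side padding disabled.
spark.conf.set("spark.sql.readSideCharPadding", "false")
Seq("12").toDF("c1").write.format("parquet").save("/tmp/char_demo")
spark.sql("CREATE TABLE t (c1 CHAR(3)) USING parquet LOCATION '/tmp/char_demo'")

// Before this fix, the predicate below was rewritten to `c1 = rpad('12', 3)`,
// i.e. `c1 = '12 '`, which never matches the stored, unpadded value '12'.
spark.sql("SELECT * FROM t WHERE c1 = '12'").show()  // previously returned no rows
```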

This PR makes Spark always pad CHAR type columns when comparing them with string literals, to satisfy the CHAR type semantics.
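
Conceptually, the rewrite performed by `ApplyCharTypePadding` now pads both sides of the comparison (a sketch, assuming the hypothetical table above with `c1 CHAR(3)`):

```scala
// After this change (with read-side padding disabled), the predicate
//   c1 = '12'
// is rewritten roughly as
//   rpad(c1, 3, ' ') = rpad('12', 3, ' ')
// so the unpadded stored value '12' still satisfies the CHAR(3) semantics.
spark.sql("SELECT * FROM t WHERE c1 = '12'").show()  // now returns the row
```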

### Why are the changes needed?

Bug fix for users who disable read-side char padding.

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, comparing a CHAR type column with STRING literals follows the CHAR semantics, whereas before it mostly returned false when padding was not applied.
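
Users who depend on the old behavior can restore it via the internal legacy flag added by this PR:

```scala
// Skip padding of the CHAR column in predicates when read-side padding is
// disabled, i.e. fall back to the pre-fix comparison behavior.
spark.conf.set("spark.sql.legacy.noCharPaddingInPredicate", "true")
```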

### How was this patch tested?

new tests

### Was this patch authored or co-authored using generative AI tooling?

no

Closes apache#46832 from cloud-fan/char.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
cloud-fan authored and jackylee-ch committed Jun 12, 2024
1 parent a00c115 commit dbc4d0a
Showing 4 changed files with 70 additions and 13 deletions.
@@ -4175,6 +4175,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate")
.internal()
.doc("When true, Spark will not apply char type padding for CHAR type columns in string " +
s"comparison predicates, when '${READ_SIDE_CHAR_PADDING.key}' is false.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val CLI_PRINT_HEADER =
buildConf("spark.sql.cli.print.header")
.doc("When set to true, spark-sql CLI prints the names of the columns in query output.")
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{CharType, Metadata, StringType}
import org.apache.spark.unsafe.types.UTF8String

@@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)
})
}
paddingForStringComparison(newPlan)
paddingForStringComparison(newPlan, padCharCol = false)
} else {
paddingForStringComparison(plan)
paddingForStringComparison(
plan, padCharCol = !conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE))
}
}

@@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}

private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = {
private def paddingForStringComparison(plan: LogicalPlan, padCharCol: Boolean): LogicalPlan = {
plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) {
case operator => operator.transformExpressionsUpWithPruning(
_.containsAnyPattern(BINARY_COMPARISON, IN)) {
@@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
// String literal is treated as char type when it's compared to a char type column.
// We should pad the shorter one to the longer length.
case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable =>
padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
b.withNewChildren(newChildren)
}.getOrElse(b)

case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable =>
padAttrLitCmp(e, attr.metadata, lit).map { newChildren =>
padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren =>
b.withNewChildren(newChildren.reverse)
}.getOrElse(b)

@@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
val literalCharLengths = literalChars.map(_.numChars())
val targetLen = (length +: literalCharLengths).max
Some(i.copy(
value = addPadding(e, length, targetLen),
value = addPadding(e, length, targetLen, alwaysPad = padCharCol),
list = list.zip(literalCharLengths).map {
case (lit, charLength) => addPadding(lit, charLength, targetLen)
case (lit, charLength) =>
addPadding(lit, charLength, targetLen, alwaysPad = false)
} ++ nulls.map(Literal.create(_, StringType))))
case _ => None
}.getOrElse(i)
@@ -162,6 +165,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
private def padAttrLitCmp(
expr: Expression,
metadata: Metadata,
padCharCol: Boolean,
lit: Expression): Option[Seq[Expression]] = {
if (expr.dataType == StringType) {
CharVarcharUtils.getRawType(metadata).flatMap {
@@ -174,7 +178,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
if (length < stringLitLen) {
Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit))
} else if (length > stringLitLen) {
Some(Seq(expr, StringRPad(lit, Literal(length))))
val paddedExpr = if (padCharCol) {
StringRPad(expr, Literal(length))
} else {
expr
}
Some(Seq(paddedExpr, StringRPad(lit, Literal(length))))
} else if (padCharCol) {
Some(Seq(StringRPad(expr, Literal(length)), lit))
} else {
None
}
@@ -186,7 +197,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}

private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
private def addPadding(
expr: Expression,
charLength: Int,
targetLength: Int,
alwaysPad: Boolean): Expression = {
if (targetLength > charLength) {
StringRPad(expr, Literal(targetLength))
} else if (alwaysPad) {
StringRPad(expr, Literal(charLength))
} else expr
}
}
@@ -1064,6 +1064,34 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa
}
}
}

test("SPARK-48498: always do char padding in predicates") {
import testImplicits._
withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") {
withTempPath { dir =>
withTable("t") {
Seq(
"12" -> "12",
"12" -> "12 ",
"12 " -> "12",
"12 " -> "12 "
).toDF("c1", "c2").write.format(format).save(dir.toString)
sql(s"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'")
// Comparing CHAR column with STRING column directly compares the stored value.
checkAnswer(
sql("SELECT c1 = c2 FROM t"),
Seq(Row(true), Row(false), Row(false), Row(true))
)
// No matter the CHAR type value is padded or not in the storage, we should always pad it
// before comparison with STRING literals.
checkAnswer(
sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t WHERE c2 = '12'"),
Seq(Row(true, true, true), Row(true, true, true))
)
}
}
}
}
}

class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
@@ -256,9 +256,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite {
protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = {
val queryString = resourceToString(s"$tpcdsGroup/$query.sql",
classLoader = Thread.currentThread().getContextClassLoader)
// Disable char/varchar read-side handling for better performance.
withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
withSQLConf(
// Disable char/varchar read-side handling for better performance.
SQLConf.READ_SIDE_CHAR_PADDING.key -> "false",
SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {
val qe = sql(queryString).queryExecution
val plan = qe.executedPlan
val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode)))