[SPARK-48831][CONNECT] Make default column name of cast compatible with Spark Classic

### What changes were proposed in this pull request?

I think there are two issues regarding the default column name of `cast`:
1. It is unclear when the name is the input column name and when it is `CAST(...)`, e.g. in Spark Classic:
```
scala> spark.range(1).select(col("id").cast("string"), lit(1).cast("string"), col("id").cast("long"), lit(1).cast("long")).printSchema
warning: 1 deprecation (since 2.13.3); for details, enable `:setting -deprecation` or `:replay -deprecation`
root
 |-- id: string (nullable = false)
 |-- CAST(1 AS STRING): string (nullable = false)
 |-- id: long (nullable = false)
 |-- CAST(1 AS BIGINT): long (nullable = false)
```
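The `printSchema` output above suggests a rule that is not obvious from the API: a user-specified cast of a plain attribute keeps the attribute's name, while a cast of any other expression is rendered as `CAST(<child> AS <TYPE>)`. A minimal pure-Python sketch of that naming rule (the helper `default_cast_name` is hypothetical, not Spark code):

```python
# Hypothetical sketch of the naming rule implied by the printSchema output
# above: a user-specified cast of a plain column keeps the column's name,
# while casting any other expression yields "CAST(<child> AS <TYPE>)".
def default_cast_name(child_name: str, sql_type: str, child_is_attribute: bool) -> str:
    if child_is_attribute:
        # e.g. col("id").cast("string") -> column named "id"
        return child_name
    # e.g. lit(1).cast("long") -> column named "CAST(1 AS BIGINT)"
    return f"CAST({child_name} AS {sql_type})"


print(default_cast_name("id", "STRING", True))
print(default_cast_name("1", "BIGINT", False))
```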

2. The column name is not consistent between Spark Connect and Spark Classic.

This PR resolves the second issue: it makes the default column name of `cast` compatible with Spark Classic, by following the classic implementation:
https://github.com/apache/spark/blob/9cf6dc873ff34412df6256cdc7613eed40716570/sql/core/src/main/scala/org/apache/spark/sql/Column.scala#L1208-L1212

### Why are the changes needed?
The default column name of `cast` in Spark Connect is not consistent with Spark Classic.

### Does this PR introduce _any_ user-facing change?
Yes, the default column name of `cast` changes in Spark Connect:

Spark Classic:
```
In [2]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+-------------------------+-------------------+-------------------+-------------------+
|CAST(X'313233' AS STRING)|CAST(123 AS STRING)|CAST(123 AS BIGINT)|CAST(123 AS DOUBLE)|
+-------------------------+-------------------+-------------------+-------------------+
|                      123|                123|                123|              123.0|
+-------------------------+-------------------+-------------------+-------------------+
```

Spark Connect (before):
```
In [3]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+---------+---+---+-----+
|X'313233'|123|123|  123|
+---------+---+---+-----+
|      123|123|123|123.0|
+---------+---+---+-----+
```

Spark Connect (after):
```
In [2]: spark.range(1).select(sf.lit(b'123').cast("STRING"), sf.lit(123).cast("STRING"), sf.lit(123).cast("LONG"), sf.lit(123).cast("DOUBLE")).show()
+-------------------------+-------------------+-------------------+-------------------+
|CAST(X'313233' AS STRING)|CAST(123 AS STRING)|CAST(123 AS BIGINT)|CAST(123 AS DOUBLE)|
+-------------------------+-------------------+-------------------+-------------------+
|                      123|                123|                123|              123.0|
+-------------------------+-------------------+-------------------+-------------------+
```
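Note in the outputs above that `cast("LONG")` renders as `CAST(123 AS BIGINT)`: the user-facing type alias is first resolved to a Catalyst data type, and the generated name uses that type's canonical SQL form. A small pure-Python illustration of this step (the mapping and helper below are an assumed subset for illustration, not Spark's actual table):

```python
# Assumed subset of the type-alias -> canonical SQL-name resolution, for
# illustration only; Spark resolves the alias to a Catalyst DataType and
# renders that type's SQL form in the generated column name.
SQL_TYPE_NAMES = {"long": "BIGINT", "bigint": "BIGINT", "string": "STRING", "double": "DOUBLE"}


def rendered_cast_name(child: str, user_type: str) -> str:
    # The generated column name uses the canonical SQL type name, not the alias.
    return f"CAST({child} AS {SQL_TYPE_NAMES[user_type.lower()]})"


print(rendered_cast_name("123", "LONG"))
print(rendered_cast_name("X'313233'", "STRING"))
```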

### How was this patch tested?
Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#47249 from zhengruifeng/py_fix_cast.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
zhengruifeng authored and HyukjinKwon committed Jul 9, 2024
1 parent 6edfd66 commit 43b6718
Showing 10 changed files with 37 additions and 19 deletions.
@@ -1,2 +1,2 @@
-Project [ATAN2(cast(a#0 as double), b#0) AS ATAN2(a, b)#0]
+Project [ATAN2(cast(a#0 as double), b#0) AS ATAN2(CAST(a AS DOUBLE), b)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [base64(cast(g#0 as binary)) AS base64(g)#0]
+Project [base64(cast(g#0 as binary)) AS base64(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [crc32(cast(g#0 as binary)) AS crc32(g)#0L]
+Project [crc32(cast(g#0 as binary)) AS crc32(CAST(g AS BINARY))#0L]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [static_invoke(StringDecode.decode(cast(g#0 as binary), UTF-8, false, false)) AS decode(g, UTF-8)#0]
+Project [static_invoke(StringDecode.decode(cast(g#0 as binary), UTF-8, false, false)) AS decode(CAST(g AS BINARY), UTF-8)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [md5(cast(g#0 as binary)) AS md5(g)#0]
+Project [md5(cast(g#0 as binary)) AS md5(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [sha1(cast(g#0 as binary)) AS sha1(g)#0]
+Project [sha1(cast(g#0 as binary)) AS sha1(CAST(g AS BINARY))#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,2 @@
-Project [sha2(cast(g#0 as binary), 512) AS sha2(g, 512)#0]
+Project [sha2(cast(g#0 as binary), 512) AS sha2(CAST(g AS BINARY), 512)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -2178,20 +2178,23 @@ class SparkConnectPlanner(
   }

   private def transformCast(cast: proto.Expression.Cast): Expression = {
-    val dataType = cast.getCastToTypeCase match {
+    val rawDataType = cast.getCastToTypeCase match {
       case proto.Expression.Cast.CastToTypeCase.TYPE => transformDataType(cast.getType)
       case _ => parser.parseDataType(cast.getTypeStr)
     }
-    val mode = cast.getEvalMode match {
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY => Some(EvalMode.LEGACY)
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI => Some(EvalMode.ANSI)
-      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY => Some(EvalMode.TRY)
-      case _ => None
-    }
-    mode match {
-      case Some(m) => Cast(transformExpression(cast.getExpr), dataType, None, m)
-      case _ => Cast(transformExpression(cast.getExpr), dataType)
+    val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+    val castExpr = cast.getEvalMode match {
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.LEGACY)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.ANSI)
+      case proto.Expression.Cast.EvalMode.EVAL_MODE_TRY =>
+        Cast(transformExpression(cast.getExpr), dataType, None, EvalMode.TRY)
+      case _ =>
+        Cast(transformExpression(cast.getExpr), dataType)
     }
+    castExpr.setTagValue(Cast.USER_SPECIFIED_CAST, ())
+    castExpr
   }

   private def transformUnresolvedRegex(regex: proto.Expression.UnresolvedRegex): Expression = {
@@ -985,7 +985,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
           transform(connectTestRelation.observe("my_metric", "id".protoAttr.cast("string"))))
       },
       errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
-      parameters = Map("expr" -> "\"id AS id\""))
+      parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))

     val connectPlan2 =
       connectTestRelation.observe(
@@ -1016,7 +1016,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
           connectTestRelation.observe(Observation("my_metric"), "id".protoAttr.cast("string"))))
       },
       errorClass = "INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
-      parameters = Map("expr" -> "\"id AS id\""))
+      parameters = Map("expr" -> "\"CAST(id AS STRING) AS id\""))
   }

   test("Test RandomSplit") {
15 changes: 15 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_column.py
@@ -1046,6 +1046,21 @@ def test_lambda_str_representation(self):
             ),
         )

+    def test_cast_default_column_name(self):
+        cdf = self.connect.range(1).select(
+            CF.lit(b"123").cast("STRING"),
+            CF.lit(123).cast("STRING"),
+            CF.lit(123).cast("LONG"),
+            CF.lit(123).cast("DOUBLE"),
+        )
+        sdf = self.spark.range(1).select(
+            SF.lit(b"123").cast("STRING"),
+            SF.lit(123).cast("STRING"),
+            SF.lit(123).cast("LONG"),
+            SF.lit(123).cast("DOUBLE"),
+        )
+        self.assertEqual(cdf.columns, sdf.columns)
+

 if __name__ == "__main__":
     import unittest
