Skip to content

Commit

Permalink
[SPARK-50895][SQL] Create common interface for expressions which prod…
Browse files Browse the repository at this point in the history
…uce default string type

### What changes were proposed in this pull request?
Introducing a new interface `DefaultStringProducingExpression` which should be inherited  by all expressions that produce default string type as their output.

### Why are the changes needed?
Because right now all of these expressions have hardcoded default string type and it will be infinitely easier to manipulate these expression if they had a common supertype.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This is a refactoring change only, so existing tests should be sufficient.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49576 from stefankandic/IDefaultProducingExpr.

Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
stefankandic authored and MaxGekk committed Jan 21, 2025
1 parent 001e244 commit fab0cca
Show file tree
Hide file tree
Showing 18 changed files with 130 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -64,6 +63,7 @@ case class CallMethodViaReflection(
children: Seq[Expression],
failOnError: Boolean = true)
extends Nondeterministic
with DefaultStringProducingExpression
with CodegenFallback
with QueryErrorsBase {

Expand Down Expand Up @@ -139,7 +139,6 @@ case class CallMethodViaReflection(
}

override def nullable: Boolean = true
override val dataType: DataType = SQLConf.get.defaultStringType
override protected def initializeInternal(partitionIndex: Int): Unit = {}

override protected def evalInternal(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1462,3 +1462,10 @@ case class MultiCommutativeOp(

override protected final def otherCopyArgs: Seq[AnyRef] = originalRoot :: Nil
}

/**
* Trait for expressions whose data type should be a default string type.
*/
trait DefaultStringProducingExpression extends Expression {
override def dataType: DataType = SQLConf.get.defaultStringType
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

/**
Expand All @@ -30,9 +29,10 @@ import org.apache.spark.unsafe.types.UTF8String
* - It prints binary values (either from column or struct field) using the hex format.
*/
case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with ToStringBase {

override def dataType: DataType = StringType
extends UnaryExpression
with DefaultStringProducingExpression
with TimeZoneAwareExpression
with ToStringBase {

override def nullable: Boolean = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,15 @@ case class ResolvedCollation(collationName: String) extends LeafExpression with
group = "string_funcs")
// scalastyle:on line.contains.tab
case class Collation(child: Expression)
extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
extends UnaryExpression
with RuntimeReplaceable
with ExpectsInputTypes
with DefaultStringProducingExpression {
override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild)
override lazy val replacement: Expression = {
val collationId = child.dataType.asInstanceOf[StringType].collationId
val fullyQualifiedCollationName = CollationFactory.fullyQualifiedName(collationId)
Literal.create(fullyQualifiedCollationName, SQLConf.get.defaultStringType)
Literal.create(fullyQualifiedCollationName, dataType)
}
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ case class CsvToStructs(
case class SchemaOfCsv(
child: Expression,
options: Map[String, String])
extends UnaryExpression with RuntimeReplaceable with QueryErrorsBase {
extends UnaryExpression
with RuntimeReplaceable
with DefaultStringProducingExpression
with QueryErrorsBase {

def this(child: Expression) = this(child, Map.empty[String, String])

def this(child: Expression, options: Expression) = this(
child = child,
options = ExprUtils.convertToMapData(options))

override def dataType: DataType = SQLConf.get.defaultStringType

override def nullable: Boolean = false

@transient
Expand Down Expand Up @@ -212,7 +213,10 @@ case class StructsToCsv(
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
extends UnaryExpression
with TimeZoneAwareExpression
with DefaultStringProducingExpression
with ExpectsInputTypes {
override def nullIntolerant: Boolean = true
override def nullable: Boolean = true

Expand Down Expand Up @@ -266,8 +270,6 @@ case class StructsToCsv(
(row: Any) => UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow]))
}

override def dataType: DataType = SQLConf.get.defaultStringType

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,11 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {
""",
group = "datetime_funcs",
since = "3.1.0")
case class CurrentTimeZone() extends LeafExpression with Unevaluable {
case class CurrentTimeZone()
extends LeafExpression
with DefaultStringProducingExpression
with Unevaluable {
override def nullable: Boolean = false
override def dataType: DataType = SQLConf.get.defaultStringType
override def prettyName: String = "current_timezone"
final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
}
Expand Down Expand Up @@ -918,10 +920,9 @@ case class WeekOfYear(child: Expression) extends GetDateField {
""",
group = "datetime_funcs",
since = "4.0.0")
case class MonthName(child: Expression) extends GetDateField {
case class MonthName(child: Expression) extends GetDateField with DefaultStringProducingExpression {
override val func = DateTimeUtils.getMonthName
override val funcName = "getMonthName"
override def dataType: DataType = StringType
override protected def withNewChildInternal(newChild: Expression): MonthName =
copy(child = newChild)
}
Expand All @@ -935,12 +936,11 @@ case class MonthName(child: Expression) extends GetDateField {
""",
group = "datetime_funcs",
since = "4.0.0")
case class DayName(child: Expression) extends GetDateField {
case class DayName(child: Expression) extends GetDateField with DefaultStringProducingExpression {
override val func = DateTimeUtils.getDayName
override val funcName = "getDayName"

override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
override def dataType: DataType = SQLConf.get.defaultStringType
override protected def withNewChildInternal(newChild: Expression): DayName =
copy(child = newChild)
}
Expand All @@ -963,13 +963,14 @@ case class DayName(child: Expression) extends GetDateField {
since = "1.5.0")
// scalastyle:on line.size.limit
case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes {
extends BinaryExpression
with TimestampFormatterHelper
with ImplicitCastInputTypes
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

def this(left: Expression, right: Expression) = this(left, right, None)

override def dataType: DataType = SQLConf.get.defaultStringType

override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))

Expand Down Expand Up @@ -1441,7 +1442,10 @@ abstract class UnixTime extends ToTimestamp {
since = "1.5.0")
// scalastyle:on line.size.limit
case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None)
extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes {
extends BinaryExpression
with TimestampFormatterHelper
with ImplicitCastInputTypes
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

def this(sec: Expression, format: Expression) = this(sec, format, None)
Expand All @@ -1455,7 +1459,6 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
this(unix, Literal(TimestampFormatter.defaultPattern()))
}

override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ import org.apache.spark.util.ArrayImplicits._
since = "1.5.0",
group = "hash_funcs")
case class Md5(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression
with ImplicitCastInputTypes
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

override def dataType: DataType = SQLConf.get.defaultStringType

override def inputTypes: Seq[DataType] = Seq(BinaryType)

protected override def nullSafeEval(input: Any): Any =
Expand Down Expand Up @@ -102,10 +102,12 @@ case class Md5(child: Expression)
group = "hash_funcs")
// scalastyle:on line.size.limit
case class Sha2(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes with Serializable {
extends BinaryExpression
with ImplicitCastInputTypes
with Serializable
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true

override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
Expand Down Expand Up @@ -169,11 +171,11 @@ case class Sha2(left: Expression, right: Expression)
since = "1.5.0",
group = "hash_funcs")
case class Sha1(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression
with ImplicitCastInputTypes
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

override def dataType: DataType = SQLConf.get.defaultStringType

override def inputTypes: Seq[DataType] = Seq(BinaryType)

protected override def nullSafeEval(input: Any): Any =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, LongType}
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -36,12 +35,13 @@ import org.apache.spark.unsafe.types.UTF8String
since = "1.5.0",
group = "misc_funcs")
// scalastyle:on whitespace.end.of.line
case class InputFileName() extends LeafExpression with Nondeterministic {
case class InputFileName()
extends LeafExpression
with Nondeterministic
with DefaultStringProducingExpression {

override def nullable: Boolean = false

override def dataType: DataType = SQLConf.get.defaultStringType

override def prettyName: String = "input_file_name"

override protected def initializeInternal(partitionIndex: Int): Unit = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ import org.apache.spark.unsafe.types.UTF8String
group = "json_funcs",
since = "1.5.0")
case class GetJsonObject(json: Expression, path: Expression)
extends BinaryExpression with ExpectsInputTypes {
extends BinaryExpression
with ExpectsInputTypes
with DefaultStringProducingExpression {

override def left: Expression = json
override def right: Expression = path
override def inputTypes: Seq[AbstractDataType] =
Seq(
StringTypeWithCollation(supportsTrimCollation = true),
StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true
override def prettyName: String = "get_json_object"

Expand Down Expand Up @@ -382,6 +383,7 @@ case class StructsToJson(
with RuntimeReplaceable
with ExpectsInputTypes
with TimeZoneAwareExpression
with DefaultStringProducingExpression
with QueryErrorsBase {

override def nullable: Boolean = true
Expand All @@ -401,8 +403,6 @@ case class StructsToJson(
@transient
private lazy val inputSchema = child.dataType

override def dataType: DataType = SQLConf.get.defaultStringType

override def checkInputDataTypes(): TypeCheckResult = inputSchema match {
case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) =>
JacksonUtils.verifyType(prettyName, dt)
Expand Down Expand Up @@ -453,6 +453,7 @@ case class SchemaOfJson(
options: Map[String, String])
extends UnaryExpression
with RuntimeReplaceable
with DefaultStringProducingExpression
with QueryErrorsBase {

def this(child: Expression) = this(child, Map.empty[String, String])
Expand All @@ -461,8 +462,6 @@ case class SchemaOfJson(
child = child,
options = ExprUtils.convertToMapData(options))

override def dataType: DataType = SQLConf.get.defaultStringType

override def nullable: Boolean = false

@transient
Expand Down Expand Up @@ -573,11 +572,12 @@ case class LengthOfJsonArray(child: Expression)
case class JsonObjectKeys(child: Expression)
extends UnaryExpression
with ExpectsInputTypes
with RuntimeReplaceable {
with RuntimeReplaceable
with DefaultStringProducingExpression {

override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
override def dataType: DataType = ArrayType(super.dataType)
override def nullable: Boolean = true
override def prettyName: String = "json_object_keys"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1005,10 +1005,12 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
group = "math_funcs")
// scalastyle:on line.size.limit
case class Bin(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes with Serializable {
extends UnaryExpression
with ImplicitCastInputTypes
with Serializable
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true
override def inputTypes: Seq[DataType] = Seq(LongType)
override def dataType: DataType = SQLConf.get.defaultStringType

protected override def nullSafeEval(input: Any): Any =
UTF8String.toBinaryString(input.asInstanceOf[Long])
Expand Down Expand Up @@ -1114,15 +1116,17 @@ object Hex {
since = "1.5.0",
group = "math_funcs")
case class Hex(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes {
extends UnaryExpression
with ImplicitCastInputTypes
with DefaultStringProducingExpression {
override def nullIntolerant: Boolean = true

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, BinaryType, StringTypeWithCollation(supportsTrimCollation = true)))

override def dataType: DataType = child.dataType match {
case st: StringType => st
case _ => SQLConf.get.defaultStringType
case _ => super.dataType
}

protected override def nullSafeEval(num: Any): Any = child.dataType match {
Expand Down
Loading

0 comments on commit fab0cca

Please sign in to comment.