From fab0cca80eb47431b701ed5afe241d46345d70c7 Mon Sep 17 00:00:00 2001
From: Stefan Kandic
Date: Tue, 21 Jan 2025 21:05:01 +0200
Subject: [PATCH] [SPARK-50895][SQL] Create common interface for expressions
 which produce default string type

### What changes were proposed in this pull request?

Introduce a new interface, `DefaultStringProducingExpression`, which should be inherited by all expressions that produce the default string type as their output.

### Why are the changes needed?

Right now, every one of these expressions hardcodes the default string type; giving them a common supertype makes them much easier to manipulate.
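For example, an expression that used to hardcode its result type can drop the `dataType` override entirely and mix the trait in. A minimal sketch (`MyStringExpr` is a hypothetical expression used purely for illustration, not part of this patch):

```scala
// Before: the default string type is hardcoded in every such expression.
case class MyStringExpr(child: Expression) extends UnaryExpression {
  override def dataType: DataType = SQLConf.get.defaultStringType
  // ... eval, withNewChildInternal, etc.
}

// After: the common supertype supplies dataType.
case class MyStringExpr(child: Expression)
  extends UnaryExpression
  with DefaultStringProducingExpression {
  // ... eval, withNewChildInternal, etc.
}
```

Having a common supertype also lets analyzer and optimizer rules handle all such expressions in one place, e.g. matching `case e: DefaultStringProducingExpression => ...` instead of enumerating every expression that returns a default-collated string.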
### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This is a refactoring change only, so existing tests should be sufficient.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49576 from stefankandic/IDefaultProducingExpr.

Authored-by: Stefan Kandic
Signed-off-by: Max Gekk
---
 .../expressions/CallMethodViaReflection.scala |  3 +-
 .../sql/catalyst/expressions/Expression.scala |  7 ++++
 .../catalyst/expressions/ToPrettyString.scala |  8 ++---
 .../expressions/collationExpressions.scala    |  7 ++--
 .../catalyst/expressions/csvExpressions.scala | 14 ++++----
 .../expressions/datetimeExpressions.scala     | 25 +++++++-------
 .../spark/sql/catalyst/expressions/hash.scala | 18 +++++-----
 .../catalyst/expressions/inputFileBlock.scala |  8 ++---
 .../expressions/jsonExpressions.scala         | 16 ++++-----
 .../expressions/mathExpressions.scala         | 12 ++++---
 .../spark/sql/catalyst/expressions/misc.scala | 33 +++++++++++--------
 .../expressions/numberFormatExpressions.scala |  4 +--
 .../expressions/randomExpressions.scala       |  8 +++--
 .../expressions/stringExpressions.scala       | 24 ++++++--------
 .../catalyst/expressions/urlExpressions.scala | 20 +++++++----
 .../variant/variantExpressions.scala          |  8 ++---
 .../sql/catalyst/expressions/xml/xpath.scala  | 12 ++++---
 .../catalyst/expressions/xmlExpressions.scala |  6 ++--
 18 files changed, 130 insertions(+), 103 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
index 4eb14fb9e7b86..cf34ceefdfee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -64,6 +63,7 @@ case class CallMethodViaReflection(
     children: Seq[Expression],
     failOnError: Boolean = true)
   extends Nondeterministic
+  with DefaultStringProducingExpression
   with CodegenFallback
   with QueryErrorsBase {
 
@@ -139,7 +139,6 @@ case class CallMethodViaReflection(
   }
 
   override def nullable: Boolean = true
-  override val dataType: DataType = SQLConf.get.defaultStringType
 
   override protected def initializeInternal(partitionIndex: Int): Unit = {}
   override protected def evalInternal(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4c83f92509ecd..a5b6a17c6ae64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -1462,3 +1462,10 @@ case class MultiCommutativeOp(
 
   override protected final def otherCopyArgs: Seq[AnyRef] = originalRoot :: Nil
 }
+
+/**
+ * Trait for expressions whose data type should be a default string type.
+ */
+trait DefaultStringProducingExpression extends Expression {
+  override def dataType: DataType = SQLConf.get.defaultStringType
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
index d18fa7c138927..e24b741f6e292 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyString.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{DataType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
@@ -30,9 +29,10 @@ import org.apache.spark.unsafe.types.UTF8String
  * - It prints binary values (either from column or struct field) using the hex format.
  */
 case class ToPrettyString(child: Expression, timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ToStringBase {
-
-  override def dataType: DataType = StringType
+  extends UnaryExpression
+  with DefaultStringProducingExpression
+  with TimeZoneAwareExpression
+  with ToStringBase {
 
   override def nullable: Boolean = false
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
index 024bef08b5273..396bc160e0329 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -151,12 +151,15 @@ case class ResolvedCollation(collationName: String) extends LeafExpression with
   group = "string_funcs")
 // scalastyle:on line.contains.tab
 case class Collation(child: Expression)
-  extends UnaryExpression with RuntimeReplaceable with ExpectsInputTypes {
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with ExpectsInputTypes
+  with DefaultStringProducingExpression {
   override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild)
   override lazy val replacement: Expression = {
     val collationId = child.dataType.asInstanceOf[StringType].collationId
     val fullyQualifiedCollationName = CollationFactory.fullyQualifiedName(collationId)
-    Literal.create(fullyQualifiedCollationName, SQLConf.get.defaultStringType)
+    Literal.create(fullyQualifiedCollationName, dataType)
   }
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index 04fb9bc133c67..b87e07977427e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -143,7 +143,10 @@ case class CsvToStructs(
 case class SchemaOfCsv(
     child: Expression,
     options: Map[String, String])
-  extends UnaryExpression with RuntimeReplaceable with QueryErrorsBase {
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with DefaultStringProducingExpression
+  with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
 
@@ -151,8 +154,6 @@ case class SchemaOfCsv(
     child = child,
     options = ExprUtils.convertToMapData(options))
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def nullable: Boolean = false
 
   @transient
@@ -212,7 +213,10 @@ case class StructsToCsv(
     options: Map[String, String],
     child: Expression,
     timeZoneId: Option[String] = None)
-  extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes {
+  extends UnaryExpression
+  with TimeZoneAwareExpression
+  with DefaultStringProducingExpression
+  with ExpectsInputTypes {
   override def nullIntolerant: Boolean = true
 
   override def nullable: Boolean = true
@@ -266,8 +270,6 @@ case class StructsToCsv(
     (row: Any) => UTF8String.fromString(gen.writeToString(row.asInstanceOf[InternalRow]))
   }
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 81be40b3b6474..67d9aff947cfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -110,9 +110,11 @@ trait TimestampFormatterHelper extends TimeZoneAwareExpression {
   """,
   group = "datetime_funcs",
   since = "3.1.0")
-case class CurrentTimeZone() extends LeafExpression with Unevaluable {
+case class CurrentTimeZone()
+  extends LeafExpression
+  with DefaultStringProducingExpression
+  with Unevaluable {
   override def nullable: Boolean = false
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def prettyName: String = "current_timezone"
   final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
 }
@@ -918,10 +920,9 @@ case class WeekOfYear(child: Expression) extends GetDateField {
   """,
   group = "datetime_funcs",
   since = "4.0.0")
-case class MonthName(child: Expression) extends GetDateField {
+case class MonthName(child: Expression) extends GetDateField with DefaultStringProducingExpression {
   override val func = DateTimeUtils.getMonthName
   override val funcName = "getMonthName"
-  override def dataType: DataType = StringType
   override protected def withNewChildInternal(newChild: Expression): MonthName =
     copy(child = newChild)
 }
@@ -935,12 +936,11 @@ case class MonthName(child: Expression) extends GetDateField {
   """,
   group = "datetime_funcs",
   since = "4.0.0")
-case class DayName(child: Expression) extends GetDateField {
+case class DayName(child: Expression) extends GetDateField with DefaultStringProducingExpression {
   override val func = DateTimeUtils.getDayName
   override val funcName = "getDayName"
 
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override protected def withNewChildInternal(newChild: Expression): DayName =
     copy(child = newChild)
 }
@@ -963,13 +963,14 @@ case class DayName(child: Expression) extends GetDateField {
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None)
-  extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes {
+  extends BinaryExpression
+  with TimestampFormatterHelper
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
   def this(left: Expression, right: Expression) = this(left, right, None)
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TimestampType, StringTypeWithCollation(supportsTrimCollation = true))
 
@@ -1441,7 +1442,10 @@ abstract class UnixTime extends ToTimestamp {
   since = "1.5.0")
 // scalastyle:on line.size.limit
 case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None)
-  extends BinaryExpression with TimestampFormatterHelper with ImplicitCastInputTypes {
+  extends BinaryExpression
+  with TimestampFormatterHelper
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
   def this(sec: Expression, format: Expression) = this(sec, format, None)
@@ -1455,7 +1459,6 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
     this(unix, Literal(TimestampFormatter.defaultPattern()))
   }
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def nullable: Boolean = true
 
   override def inputTypes: Seq[AbstractDataType] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 89d2259ea5c28..ac493d19df1b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -61,11 +61,11 @@ import org.apache.spark.util.ArrayImplicits._
   since = "1.5.0",
   group = "hash_funcs")
 case class Md5(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   protected override def nullSafeEval(input: Any): Any =
@@ -102,10 +102,12 @@ case class Md5(child: Expression)
   group = "hash_funcs")
 // scalastyle:on line.size.limit
 case class Sha2(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with Serializable {
+  extends BinaryExpression
+  with ImplicitCastInputTypes
+  with Serializable
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def nullable: Boolean = true
 
   override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
@@ -169,11 +171,11 @@ case class Sha2(left: Expression, right: Expression)
   since = "1.5.0",
   group = "hash_funcs")
 case class Sha1(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   protected override def nullSafeEval(input: Any): Any =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 65eb995ff32ff..cf860724c1f60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -21,7 +21,6 @@ import org.apache.spark.rdd.InputFileBlockHolder
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, LongType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -36,12 +35,13 @@ import org.apache.spark.unsafe.types.UTF8String
   since = "1.5.0",
   group = "misc_funcs")
 // scalastyle:on whitespace.end.of.line
-case class InputFileName() extends LeafExpression with Nondeterministic {
+case class InputFileName()
+  extends LeafExpression
+  with Nondeterministic
+  with DefaultStringProducingExpression {
 
   override def nullable: Boolean = false
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def prettyName: String = "input_file_name"
 
   override protected def initializeInternal(partitionIndex: Int): Unit = {}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index e80f543f14eda..5e6da7ac41250 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -46,7 +46,9 @@ import org.apache.spark.unsafe.types.UTF8String
   group = "json_funcs",
   since = "1.5.0")
 case class GetJsonObject(json: Expression, path: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression
+  with ExpectsInputTypes
+  with DefaultStringProducingExpression {
 
   override def left: Expression = json
   override def right: Expression = path
@@ -54,7 +56,6 @@ case class GetJsonObject(json: Expression, path: Expression)
     Seq(
       StringTypeWithCollation(supportsTrimCollation = true),
       StringTypeWithCollation(supportsTrimCollation = true))
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def nullable: Boolean = true
   override def prettyName: String = "get_json_object"
 
@@ -382,6 +383,7 @@ case class StructsToJson(
   with RuntimeReplaceable
   with ExpectsInputTypes
   with TimeZoneAwareExpression
+  with DefaultStringProducingExpression
   with QueryErrorsBase {
 
   override def nullable: Boolean = true
@@ -401,8 +403,6 @@ case class StructsToJson(
   @transient
   private lazy val inputSchema = child.dataType
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def checkInputDataTypes(): TypeCheckResult = inputSchema match {
     case dt @ (_: StructType | _: MapType | _: ArrayType | _: VariantType) =>
       JacksonUtils.verifyType(prettyName, dt)
@@ -453,6 +453,7 @@ case class SchemaOfJson(
     options: Map[String, String])
   extends UnaryExpression
   with RuntimeReplaceable
+  with DefaultStringProducingExpression
   with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
@@ -461,8 +462,6 @@ case class SchemaOfJson(
     child = child,
     options = ExprUtils.convertToMapData(options))
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def nullable: Boolean = false
 
   @transient
@@ -573,11 +572,12 @@ case class LengthOfJsonArray(child: Expression)
 case class JsonObjectKeys(child: Expression)
   extends UnaryExpression
   with ExpectsInputTypes
-  with RuntimeReplaceable {
+  with RuntimeReplaceable
+  with DefaultStringProducingExpression {
 
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
-  override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
+  override def dataType: DataType = ArrayType(super.dataType)
   override def nullable: Boolean = true
   override def prettyName: String = "json_object_keys"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 317a08b8c64c6..6233f4613b343 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1005,10 +1005,12 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
   group = "math_funcs")
 // scalastyle:on line.size.limit
 case class Bin(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes with Serializable {
+  extends UnaryExpression
+  with ImplicitCastInputTypes
+  with Serializable
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
   override def inputTypes: Seq[DataType] = Seq(LongType)
-  override def dataType: DataType = SQLConf.get.defaultStringType
 
   protected override def nullSafeEval(input: Any): Any =
     UTF8String.toBinaryString(input.asInstanceOf[Long])
@@ -1114,7 +1116,9 @@ object Hex {
   since = "1.5.0",
   group = "math_funcs")
 case class Hex(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
   override def inputTypes: Seq[AbstractDataType] =
@@ -1122,7 +1126,7 @@ case class Hex(child: Expression)
 
   override def dataType: DataType = child.dataType match {
     case st: StringType => st
-    case _ => SQLConf.get.defaultStringType
+    case _ => super.dataType
   }
 
   protected override def nullSafeEval(num: Any): Any = child.dataType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index fb30eab327d4c..cd5aedb9bb891 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -204,8 +204,10 @@ object AssertTrue {
   """,
   since = "1.6.0",
   group = "misc_funcs")
-case class CurrentDatabase() extends LeafExpression with Unevaluable {
-  override def dataType: DataType = SQLConf.get.defaultStringType
+case class CurrentDatabase()
+  extends LeafExpression
+  with DefaultStringProducingExpression
+  with Unevaluable {
   override def nullable: Boolean = false
   override def prettyName: String = "current_schema"
   final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
@@ -223,8 +225,10 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable {
   """,
   since = "3.1.0",
   group = "misc_funcs")
-case class CurrentCatalog() extends LeafExpression with Unevaluable {
-  override def dataType: DataType = SQLConf.get.defaultStringType
+case class CurrentCatalog()
+  extends LeafExpression
+  with DefaultStringProducingExpression
+  with Unevaluable {
   override def nullable: Boolean = false
   override def prettyName: String = "current_catalog"
   final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
@@ -245,7 +249,8 @@ case class CurrentCatalog() extends LeafExpression with Unevaluable {
   group = "misc_funcs")
 // scalastyle:on line.size.limit
 case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Nondeterministic
-  with ExpressionWithRandomSeed {
+  with DefaultStringProducingExpression
+  with ExpressionWithRandomSeed {
 
   def this() = this(None)
 
@@ -259,8 +264,6 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non
 
   override def nullable: Boolean = false
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def stateful: Boolean = true
 
   @transient private[this] var randomGenerator: RandomUUIDGenerator = _
@@ -295,12 +298,15 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non
   since = "3.0.0",
   group = "misc_funcs")
 // scalastyle:on line.size.limit
-case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
+case class SparkVersion()
+  extends LeafExpression
+  with RuntimeReplaceable
+  with DefaultStringProducingExpression {
 
   override def prettyName: String = "version"
   override lazy val replacement: Expression = StaticInvoke(
     classOf[ExpressionImplUtils],
-    SQLConf.get.defaultStringType,
+    dataType,
     "getSparkVersion",
     returnNullable = false)
 }
@@ -316,10 +322,9 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable {
   """,
   since = "3.0.0",
   group = "misc_funcs")
-case class TypeOf(child: Expression) extends UnaryExpression {
+case class TypeOf(child: Expression) extends UnaryExpression with DefaultStringProducingExpression {
   override def nullable: Boolean = false
   override def foldable: Boolean = true
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def eval(input: InternalRow): Any = UTF8String.fromString(child.dataType.catalogString)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -340,9 +345,11 @@ case class TypeOf(child: Expression) extends UnaryExpression {
   since = "3.2.0",
   group = "misc_funcs")
 // scalastyle:on line.size.limit
-case class CurrentUser() extends LeafExpression with Unevaluable {
+case class CurrentUser()
+  extends LeafExpression
+  with DefaultStringProducingExpression
+  with Unevaluable {
   override def nullable: Boolean = false
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_user")
   final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
index fd6399d65271e..21dcbba818d9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.util.ToNumberParser
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -276,7 +275,7 @@ object ToCharacterBuilder extends ExpressionBuilder {
 }
 
 case class ToCharacter(left: Expression, right: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes {
+  extends BinaryExpression with ImplicitCastInputTypes with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
   private lazy val numberFormatter = {
@@ -288,7 +287,6 @@ case class ToCharacter(left: Expression, right: Expression)
     }
   }
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def inputTypes: Seq[AbstractDataType] =
     Seq(DecimalType, StringTypeWithCollation(supportsTrimCollation = true))
   override def checkInputDataTypes(): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 50c699ef69bd6..fa6eb2c111895 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -313,15 +313,17 @@ object Uniform {
   group = "string_funcs")
 case class RandStr(
     length: Expression, override val seedExpression: Expression, hideSeed: Boolean)
-  extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic
-  with ExpectsInputTypes {
+  extends ExpressionWithRandomSeed
+  with BinaryLike[Expression]
+  with DefaultStringProducingExpression
+  with Nondeterministic
+  with ExpectsInputTypes {
 
   def this(length: Expression) = this(length, UnresolvedSeed, hideSeed = true)
   def this(length: Expression, seedExpression: Expression) =
     this(length, seedExpression, hideSeed = false)
 
   override def nullable: Boolean = false
-  override def dataType: DataType = StringType
   override def stateful: Boolean = true
   override def left: Expression = length
   override def right: Expression = seedExpression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index efd7e5c07de40..b90537daabd68 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2232,9 +2232,8 @@ case class StringRepeat(str: Expression, times: Expression)
   since = "1.5.0",
   group = "string_funcs")
 case class StringSpace(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with ImplicitCastInputTypes with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def inputTypes: Seq[DataType] = Seq(IntegerType)
 
   override def nullSafeEval(s: Any): Any = {
@@ -2717,11 +2716,9 @@ case class Levenshtein(
   since = "1.5.0",
   group = "string_funcs")
 case class SoundEx(child: Expression)
-  extends UnaryExpression with ExpectsInputTypes {
+  extends UnaryExpression with ExpectsInputTypes with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def inputTypes: Seq[AbstractDataType] =
     Seq(StringTypeWithCollation(supportsTrimCollation = true))
 
@@ -2798,10 +2795,9 @@ case class Ascii(child: Expression)
   group = "string_funcs")
 // scalastyle:on line.size.limit
 case class Chr(child: Expression)
-  extends UnaryExpression with ImplicitCastInputTypes {
+  extends UnaryExpression with ImplicitCastInputTypes with DefaultStringProducingExpression {
   override def nullIntolerant: Boolean = true
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def inputTypes: Seq[DataType] = Seq(LongType)
 
   protected override def nullSafeEval(lon: Any): Any = {
@@ -2848,11 +2844,13 @@ case class Chr(child: Expression)
   since = "1.5.0",
   group = "string_funcs")
 case class Base64(child: Expression, chunkBase64: Boolean)
-  extends UnaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
 
   def this(expr: Expression) = this(expr, SQLConf.get.chunkBase64StringEnabled)
 
-  override val dataType: DataType = SQLConf.get.defaultStringType
   override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   override lazy val replacement: Expression = StaticInvoke(
@@ -3075,12 +3073,11 @@ case class StringDecode(
     charset: Expression,
     legacyCharsets: Boolean,
     legacyErrorAction: Boolean)
-  extends RuntimeReplaceable with ImplicitCastInputTypes {
+  extends RuntimeReplaceable with ImplicitCastInputTypes with DefaultStringProducingExpression {
 
   def this(bin: Expression, charset: Expression) =
     this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction)
 
-  override val dataType: DataType = SQLConf.get.defaultStringType
   override def inputTypes: Seq[AbstractDataType] = Seq(
     BinaryType,
     StringTypeWithCollation(supportsTrimCollation = true)
@@ -3090,7 +3087,7 @@ case class StringDecode(
 
   override lazy val replacement: Expression = StaticInvoke(
     classOf[StringDecode],
-    SQLConf.get.defaultStringType,
+    dataType,
     "decode",
     Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)),
     Seq(
@@ -3337,11 +3334,10 @@ case class ToBinary(
   since = "1.5.0",
   group = "string_funcs")
 case class FormatNumber(x: Expression, d: Expression)
-  extends BinaryExpression with ExpectsInputTypes {
+  extends BinaryExpression with ExpectsInputTypes with DefaultStringProducingExpression {
 
   override def left: Expression = x
   override def right: Expression = d
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def nullable: Boolean = true
   override def nullIntolerant: Boolean = true
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
index 845ca0b608ef3..b51cb74a5a8fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
-import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, ObjectType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, ObjectType}
 import org.apache.spark.unsafe.types.UTF8String
 
 // scalastyle:off line.size.limit
@@ -49,12 +49,15 @@ import org.apache.spark.unsafe.types.UTF8String
   group = "url_funcs")
 // scalastyle:on line.size.limit
 case class UrlEncode(child: Expression)
-  extends RuntimeReplaceable with UnaryLike[Expression] with ImplicitCastInputTypes {
+  extends RuntimeReplaceable
+  with UnaryLike[Expression]
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
 
   override lazy val replacement: Expression =
     StaticInvoke(
       UrlCodec.getClass,
-      SQLConf.get.defaultStringType,
+      dataType,
       "encode",
       Seq(child),
       Seq(StringTypeWithCollation(supportsTrimCollation = true)))
@@ -87,14 +90,17 @@ case class UrlEncode(child: Expression)
   group = "url_funcs")
 // scalastyle:on line.size.limit
 case class UrlDecode(child: Expression, failOnError: Boolean = true)
-  extends RuntimeReplaceable with UnaryLike[Expression] with ImplicitCastInputTypes {
+  extends RuntimeReplaceable
+  with UnaryLike[Expression]
+  with ImplicitCastInputTypes
+  with DefaultStringProducingExpression {
 
   def this(child: Expression) = this(child, true)
 
   override lazy val replacement: Expression =
     StaticInvoke(
       UrlCodec.getClass,
-      SQLConf.get.defaultStringType,
+      dataType,
       "decode",
       Seq(child, Literal(failOnError)),
       Seq(StringTypeWithCollation(supportsTrimCollation = true), BooleanType))
@@ -207,14 +213,14 @@ case class ParseUrl(
     failOnError: Boolean = SQLConf.get.ansiEnabled)
   extends Expression
   with ExpectsInputTypes
-  with RuntimeReplaceable {
+  with RuntimeReplaceable
+  with DefaultStringProducingExpression {
 
   def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
 
   override def nullable: Boolean = true
   override def inputTypes: Seq[AbstractDataType] =
     Seq.fill(children.size)(StringTypeWithCollation(supportsTrimCollation = true))
-  override def dataType: DataType = SQLConf.get.defaultStringType
   override def prettyName: String = "parse_url"
 
   override def checkInputDataTypes(): TypeCheckResult = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index ff8b168793b5d..f722329097bc0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -748,10 +748,11 @@ object VariantExplode {
 case class SchemaOfVariant(child: Expression)
   extends UnaryExpression
   with RuntimeReplaceable
+  with DefaultStringProducingExpression
   with ExpectsInputTypes {
   override lazy val replacement: Expression = StaticInvoke(
     SchemaOfVariant.getClass,
-    SQLConf.get.defaultStringType,
+    dataType,
     "schemaOfVariant",
     Seq(child),
     inputTypes,
@@ -759,8 +760,6 @@ case class SchemaOfVariant(child: Expression)
 
   override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def prettyName: String = "schema_of_variant"
 
   override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant =
@@ -859,13 +858,12 @@ case class SchemaOfVariantAgg(
   extends TypedImperativeAggregate[DataType]
   with ExpectsInputTypes
   with QueryErrorsBase
+  with DefaultStringProducingExpression
   with UnaryLike[Expression] {
   def this(child: Expression) = this(child, 0, 0)
 
   override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def nullable: Boolean = false
 
   override def createAggregationBuffer(): DataType = NullType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 2e591288a21cf..800b38ea32223 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -219,12 +218,13 @@ case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract {
   since = "2.0.0",
   group = "xml_funcs")
 // scalastyle:on line.size.limit
-case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
+case class XPathString(xml: Expression, path: Expression)
+  extends XPathExtract
+  with DefaultStringProducingExpression {
   @transient override lazy val evaluator: XPathEvaluator = XPathStringEvaluator(pathUTF8String)
 
   override def prettyName: String = "xpath_string"
-  override def dataType: DataType = SQLConf.get.defaultStringType
 
   override protected def withNewChildrenInternal(
       newLeft: Expression, newRight: Expression): Expression =
     copy(xml = newLeft, path = newRight)
@@ -243,12 +243,14 @@ case class XPathString(xml: Expression, path: Expression) extends XPathExtract {
   since = "2.0.0",
   group = "xml_funcs")
 // scalastyle:on line.size.limit
-case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
+case class XPathList(xml: Expression, path: Expression)
+  extends XPathExtract
+  with DefaultStringProducingExpression {
   @transient override lazy val evaluator: XPathEvaluator = XPathListEvaluator(pathUTF8String)
 
   override def prettyName: String = "xpath"
-  override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
+  override def dataType: DataType = ArrayType(super.dataType)
 
   override protected def withNewChildrenInternal(
       newLeft: Expression, newRight: Expression): XPathList =
     copy(xml = newLeft, path = newRight)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index d8254f04b4d94..25a054f79c368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -154,6 +154,7 @@ case class SchemaOfXml(
     options: Map[String, String])
   extends UnaryExpression
   with RuntimeReplaceable
+  with DefaultStringProducingExpression
   with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
@@ -162,8 +163,6 @@ case class SchemaOfXml(
     child = child,
     options = ExprUtils.convertToMapData(options))
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def nullable: Boolean = false
 
   @transient
@@ -240,6 +239,7 @@ case class StructsToXml(
     timeZoneId: Option[String] = None)
   extends UnaryExpression
   with TimeZoneAwareExpression
+  with DefaultStringProducingExpression
   with ExpectsInputTypes {
   override def nullable: Boolean = true
   override def nullIntolerant: Boolean = true
@@ -294,8 +294,6 @@ case class StructsToXml(
     getAndReset()
   }
 
-  override def dataType: DataType = SQLConf.get.defaultStringType
-
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))