diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
new file mode 100644
index 0000000000000..40bd1c7abc75f
--- /dev/null
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * The physical data representation of {@link org.apache.spark.sql.types.VariantType} that
+ * represents a semi-structured value. It consists of two binary values: {@link VariantVal#value}
+ * and {@link VariantVal#metadata}. The value encodes types and values, but not field names. The
+ * metadata currently contains a version flag and a list of field names. We can extend/modify the
+ * detailed binary format given the version flag.
+ *
+ * A {@link VariantVal} can be produced by casting another value into the Variant type or parsing a
+ * JSON string in the {@link org.apache.spark.sql.catalyst.expressions.variant.ParseJson}
+ * expression. We can extract a path consisting of field names and array indices from it, cast it
+ * into a concrete data type, or rebuild a JSON string from it.
+ *
+ * The storage layout of this class in {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}
+ * and {@link org.apache.spark.sql.catalyst.expressions.UnsafeArrayData} is: the fixed-size part is
+ * a long value "offsetAndSize". The upper 32 bits hold the offset that points to the start position
+ * of the actual binary content, and the lower 32 bits hold the total length of the binary content. The
+ * binary content contains: 4 bytes representing the length of {@link VariantVal#value}, content of
+ * {@link VariantVal#value}, content of {@link VariantVal#metadata}. This is an internal and
+ * transient format and can be modified at any time.
+ */
+public class VariantVal implements Serializable {
+ protected final byte[] value;
+ protected final byte[] metadata;
+
+ public VariantVal(byte[] value, byte[] metadata) {
+ this.value = value;
+ this.metadata = metadata;
+ }
+
+ public byte[] getValue() {
+ return value;
+ }
+
+ public byte[] getMetadata() {
+ return metadata;
+ }
+
+ /**
+ * This function reads the binary content described in the class comment (and written by
+ * {@code UnsafeWriter#write}) from `baseObject`. The offset is computed by adding the offset in
+ * {@code offsetAndSize} and {@code baseOffset}.
+ */
+ public static VariantVal readFromUnsafeRow(
+ long offsetAndSize,
+ Object baseObject,
+ long baseOffset) {
+ // offset and totalSize are the upper/lower 32 bits in offsetAndSize.
+ int offset = (int) (offsetAndSize >> 32);
+ int totalSize = (int) offsetAndSize;
+ int valueSize = Platform.getInt(baseObject, baseOffset + offset);
+ int metadataSize = totalSize - 4 - valueSize;
+ byte[] value = new byte[valueSize];
+ byte[] metadata = new byte[metadataSize];
+ Platform.copyMemory(
+ baseObject,
+ baseOffset + offset + 4,
+ value,
+ Platform.BYTE_ARRAY_OFFSET,
+ valueSize
+ );
+ Platform.copyMemory(
+ baseObject,
+ baseOffset + offset + 4 + valueSize,
+ metadata,
+ Platform.BYTE_ARRAY_OFFSET,
+ metadataSize
+ );
+ return new VariantVal(value, metadata);
+ }
+
+ public String debugString() {
+ return "VariantVal{" +
+ "value=" + Arrays.toString(value) +
+ ", metadata=" + Arrays.toString(metadata) +
+ '}';
+ }
+
+ /**
+ * @return A human-readable representation of the Variant value. It is always a JSON string at
+ * this moment.
+ */
+ @Override
+ public String toString() {
+ // NOTE: the encoding is not yet implemented; this is not the final implementation.
+ return new String(value);
+ }
+}
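A minimal Scala sketch (not part of the patch) of the storage layout described in the VariantVal class comment above. The buffer, the offset value, and the object name are illustrative assumptions; only Platform and VariantVal.readFromUnsafeRow come from the code above.

import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.VariantVal

object VariantLayoutSketch {
  def main(args: Array[String]): Unit = {
    val value = Array[Byte](1, 2, 3)
    val metadata = Array[Byte](-1, -2)
    // Binary content: 4-byte value length, value bytes, metadata bytes.
    val totalSize = 4 + value.length + metadata.length
    val offset = 16                            // hypothetical start position inside a row buffer
    val buffer = new Array[Byte](offset + totalSize)
    val base = Platform.BYTE_ARRAY_OFFSET
    Platform.putInt(buffer, base + offset, value.length)
    Platform.copyMemory(value, base, buffer, base + offset + 4, value.length)
    Platform.copyMemory(metadata, base, buffer, base + offset + 4 + value.length, metadata.length)
    // "offsetAndSize": offset in the upper 32 bits, total length in the lower 32 bits.
    val offsetAndSize = (offset.toLong << 32) | totalSize.toLong
    val v = VariantVal.readFromUnsafeRow(offsetAndSize, buffer, base)
    println(v.debugString())                   // VariantVal{value=[1, 2, 3], metadata=[-1, -2]}
  }
}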
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index e3b9f3161b24d..1b4c10acaf7bc 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -317,6 +317,12 @@
],
"sqlState" : "58030"
},
+ "CANNOT_SAVE_VARIANT" : {
+ "message" : [
+ "Cannot save variant data type into external storage."
+ ],
+ "sqlState" : "0A000"
+ },
"CANNOT_UPDATE_FIELD" : {
"message" : [
"Cannot update
field type:"
diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
index 0e27e4a604c46..e235c13d413e2 100644
--- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
+++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala
@@ -72,6 +72,8 @@ private[sql] object AvroUtils extends Logging {
}
def supportsDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: AtomicType => true
case st: StructType => st.forall { f => supportsDataType(f.dataType) }
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index a811019e0a57b..ee9c2fd67b307 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -280,6 +280,12 @@ SQLSTATE: 58030
Failed to set permissions on created path `<path>` back to `<permission>`.
+### CANNOT_SAVE_VARIANT
+
+[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)
+
+Cannot save variant data type into external storage.
+
### [CANNOT_UPDATE_FIELD](sql-error-conditions-cannot-update-field-error-class.html)
[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 09c38a0099599..4729db16d63f3 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -671,6 +671,7 @@ Below is a list of all the keywords in Spark SQL.
|VARCHAR|non-reserved|non-reserved|reserved|
|VAR|non-reserved|non-reserved|non-reserved|
|VARIABLE|non-reserved|non-reserved|non-reserved|
+|VARIANT|non-reserved|non-reserved|reserved|
|VERSION|non-reserved|non-reserved|non-reserved|
|VIEW|non-reserved|non-reserved|non-reserved|
|VIEWS|non-reserved|non-reserved|non-reserved|
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index e8b5cb012fcae..9b3dcbc6d194f 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -408,6 +408,7 @@ VALUES: 'VALUES';
VARCHAR: 'VARCHAR';
VAR: 'VAR';
VARIABLE: 'VARIABLE';
+VARIANT: 'VARIANT';
VERSION: 'VERSION';
VIEW: 'VIEW';
VIEWS: 'VIEWS';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index bd449a4e194e8..609bd72e21935 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -1086,6 +1086,7 @@ type
| DECIMAL | DEC | NUMERIC
| VOID
| INTERVAL
+ | VARIANT
| ARRAY | STRUCT | MAP
| unsupportedType=identifier
;
@@ -1545,6 +1546,7 @@ ansiNonReserved
| VARCHAR
| VAR
| VARIABLE
+ | VARIANT
| VERSION
| VIEW
| VIEWS
@@ -1893,6 +1895,7 @@ nonReserved
| VARCHAR
| VAR
| VARIABLE
+ | VARIANT
| VERSION
| VIEW
| VIEWS
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index e5e9ba644b814..9133abce88adc 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -25,7 +25,7 @@ import scala.reflect.{classTag, ClassTag}
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkClassUtils
/**
@@ -216,6 +216,7 @@ object AgnosticEncoders {
case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType)
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
+ case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
case class DateEncoder(override val lenientSerialization: Boolean)
extends LeafEncoder[jsql.Date](DateType)
case class LocalDateEncoder(override val lenientSerialization: Boolean)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 69661c343c5b1..a201da9c95c9e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VariantEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -90,6 +90,7 @@ object RowEncoder {
case CalendarIntervalType => CalendarIntervalEncoder
case _: DayTimeIntervalType => DayTimeIntervalEncoder
case _: YearMonthIntervalType => YearMonthIntervalEncoder
+ case _: VariantType => VariantEncoder
case p: PythonUserDefinedType =>
// TODO check if this works.
encoderForDataType(p.sqlType, lenient)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index b30c6fa29e829..3a2e704ffe9f7 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
protected def typedVisit[T](ctx: ParseTree): T = {
@@ -82,6 +82,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
DecimalType(precision.getText.toInt, scale.getText.toInt)
case (VOID, Nil) => NullType
case (INTERVAL, Nil) => CalendarIntervalType
+ case (VARIANT, Nil) => VariantType
case (CHARACTER | CHAR | VARCHAR, Nil) =>
throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
case (ARRAY | STRUCT | MAP, Nil) =>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 5f563e3b7a8f1..94252de48d1ea 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -173,7 +173,8 @@ object DataType {
YearMonthIntervalType(YEAR),
YearMonthIntervalType(MONTH),
YearMonthIntervalType(YEAR, MONTH),
- TimestampNTZType)
+ TimestampNTZType,
+ VariantType)
.map(t => t.typeName -> t).toMap
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala
new file mode 100644
index 0000000000000..103fe7a59fc83
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import org.apache.spark.annotation.Unstable
+
+/**
+ * The data type representing semi-structured values with arbitrary hierarchical data structures.
+ * It is intended to store parsed JSON values and can represent most other data types in the
+ * system, with a few exceptions (e.g., it cannot store a map with a non-string key type).
+ *
+ * @since 4.0.0
+ */
+@Unstable
+class VariantType private () extends AtomicType {
+ // The default size is used in query planning to drive optimization decisions. 2048 is arbitrarily
+ // picked and we currently don't have any data to support it. This may need revisiting later.
+ override def defaultSize: Int = 2048
+
+ /** This is a no-op because values with VARIANT type are always nullable. */
+ private[spark] override def asNullable: VariantType = this
+}
+
+/**
+ * @since 4.0.0
+ */
+@Unstable
+case object VariantType extends VariantType
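A short sketch, assuming a Spark build with these changes applied, showing that the new type can be referenced by name thanks to the parser rule and the DataType name registration above; the object name is an assumption.

import org.apache.spark.sql.types._

object VariantTypeSketch {
  def main(args: Array[String]): Unit = {
    // Use VariantType in a schema like any other data type.
    val schema = StructType(Seq(StructField("v", VariantType)))
    println(schema.toDDL)                    // expected: `v` VARIANT
    // The parser rule added above resolves the DDL type name ...
    println(DataType.fromDDL("variant"))
    // ... and the nameToType registration handles the JSON type name.
    println(DataType.fromJson("\"variant\""))
    println(VariantType.defaultSize)         // 2048, the planning estimate chosen above
  }
}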
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
index e7af5a68b4663..ffc3c8eaf8f84 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
@@ -45,7 +45,7 @@ public class ExpressionInfo {
"collection_funcs", "predicate_funcs", "conditional_funcs", "conversion_funcs",
"csv_funcs", "datetime_funcs", "generator_funcs", "hash_funcs", "json_funcs",
"lambda_funcs", "map_funcs", "math_funcs", "misc_funcs", "string_funcs", "struct_funcs",
- "window_funcs", "xml_funcs", "table_funcs", "url_funcs"));
+ "window_funcs", "xml_funcs", "table_funcs", "url_funcs", "variant_funcs"));
private static final Set<String> validSources =
new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf",
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
index eea7149d02594..b88a892db4b46 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java
@@ -24,6 +24,7 @@
import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
public interface SpecializedGetters {
@@ -51,6 +52,8 @@ public interface SpecializedGetters {
CalendarInterval getInterval(int ordinal);
+ VariantVal getVariant(int ordinal);
+
InternalRow getStruct(int ordinal, int numFields);
ArrayData getArray(int ordinal);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
index 91f04c3d327ac..9e508dbb271cf 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java
@@ -66,6 +66,9 @@ public static Object read(
if (physicalDataType instanceof PhysicalBinaryType) {
return obj.getBinary(ordinal);
}
+ if (physicalDataType instanceof PhysicalVariantType) {
+ return obj.getVariant(ordinal);
+ }
if (physicalDataType instanceof PhysicalStructType) {
return obj.getStruct(ordinal, ((PhysicalStructType) physicalDataType).fields().length);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index ea6f1e05422b5..700e42cb843c8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -38,6 +38,7 @@
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -231,6 +232,12 @@ public CalendarInterval getInterval(int ordinal) {
return new CalendarInterval(months, days, microseconds);
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ if (isNullAt(ordinal)) return null;
+ return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset);
+ }
+
@Override
public UnsafeRow getStruct(int ordinal, int numFields) {
if (isNullAt(ordinal)) return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 8f9d5919e1d9f..fca45c58beed0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -36,6 +36,7 @@
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
@@ -417,6 +418,12 @@ public CalendarInterval getInterval(int ordinal) {
}
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ if (isNullAt(ordinal)) return null;
+ return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset);
+ }
+
@Override
public UnsafeRow getStruct(int ordinal, int numFields) {
if (isNullAt(ordinal)) {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 8d4e187d01a12..d651e5ab5b3e5 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -25,6 +25,7 @@
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* Base class for writing Unsafe* structures.
@@ -149,6 +150,27 @@ public void write(int ordinal, CalendarInterval input) {
increaseCursor(16);
}
+ public void write(int ordinal, VariantVal input) {
+ // See the class comment of VariantVal for the format of the binary content.
+ byte[] value = input.getValue();
+ byte[] metadata = input.getMetadata();
+ int totalSize = 4 + value.length + metadata.length;
+ int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSize);
+ grow(roundedSize);
+ zeroOutPaddingBytes(totalSize);
+ Platform.putInt(getBuffer(), cursor(), value.length);
+ Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor() + 4, value.length);
+ Platform.copyMemory(
+ metadata,
+ Platform.BYTE_ARRAY_OFFSET,
+ getBuffer(),
+ cursor() + 4 + value.length,
+ metadata.length
+ );
+ setOffsetAndSize(ordinal, totalSize);
+ increaseCursor(roundedSize);
+ }
+
public final void write(int ordinal, UnsafeRow row) {
writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes());
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 73c2cf2cc05f8..cd3c30fa69335 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -22,6 +22,7 @@
import org.apache.spark.sql.types.UserDefinedType;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* An interface representing in-memory columnar data in Spark. This interface defines the main APIs
@@ -299,6 +300,16 @@ public CalendarInterval getInterval(int rowId) {
return new CalendarInterval(months, days, microseconds);
}
+ /**
+ * Returns the Variant value for {@code rowId}. Similar to {@link #getInterval(int)},
+ * implementations must support {@link #getChild(int)} and define two child vectors of binary
+ * type: one for the Variant value and one for the Variant metadata.
+ */
+ public final VariantVal getVariant(int rowId) {
+ if (isNullAt(rowId)) return null;
+ return new VariantVal(getChild(0).getBinary(rowId), getChild(1).getBinary(rowId));
+ }
+
/**
* @return child {@link ColumnVector} at the given ordinal.
*/
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
index bd7c3d7c0fd49..e0141a575b299 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
@@ -24,6 +24,7 @@
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* Array abstraction in {@link ColumnVector}.
@@ -160,6 +161,11 @@ public CalendarInterval getInterval(int ordinal) {
return data.getInterval(offset + ordinal);
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return data.getVariant(offset + ordinal);
+ }
+
@Override
public ColumnarRow getStruct(int ordinal, int numFields) {
return data.getStruct(offset + ordinal);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
index c0d2ae8e7d0e8..ac23f70584e89 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
@@ -23,6 +23,7 @@
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* This class wraps an array of {@link ColumnVector} and provides a row view.
@@ -133,6 +134,11 @@ public CalendarInterval getInterval(int ordinal) {
return columns[ordinal].getInterval(rowId);
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return columns[ordinal].getVariant(rowId);
+ }
+
@Override
public ColumnarRow getStruct(int ordinal, int numFields) {
return columns[ordinal].getStruct(rowId);
@@ -182,6 +188,8 @@ public Object get(int ordinal, DataType dataType) {
return getStruct(ordinal, ((StructType)dataType).fields().length);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof VariantType) {
+ return getVariant(ordinal);
} else {
throw new UnsupportedOperationException("Datatype not supported " + dataType);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index 1df4653f55276..18f6779cccb96 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -23,6 +23,7 @@
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* Row abstraction in {@link ColumnVector}.
@@ -140,6 +141,11 @@ public CalendarInterval getInterval(int ordinal) {
return data.getChild(ordinal).getInterval(rowId);
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return data.getChild(ordinal).getVariant(rowId);
+ }
+
@Override
public ColumnarRow getStruct(int ordinal, int numFields) {
return data.getChild(ordinal).getStruct(rowId);
@@ -187,6 +193,8 @@ public Object get(int ordinal, DataType dataType) {
return getStruct(ordinal, ((StructType)dataType).fields().length);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof VariantType) {
+ return getVariant(ordinal);
} else {
throw new UnsupportedOperationException("Datatype not supported " + dataType);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
index 429ce805bf2c4..034b959c5a383 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StructType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
/**
* An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying
@@ -99,6 +99,10 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte
row.getInterval(colOrdinals(ordinal))
}
+ override def getVariant(ordinal: Int): VariantVal = {
+ row.getVariant(colOrdinals(ordinal))
+ }
+
override def getStruct(ordinal: Int, numFields: Int): InternalRow = {
row.getStruct(colOrdinals(ordinal), numFields)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 23d63011db53f..4fb8d88f6eab1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -808,6 +809,9 @@ object FunctionRegistry {
expression[LengthOfJsonArray]("json_array_length"),
expression[JsonObjectKeys]("json_object_keys"),
+ // Variant
+ expression[ParseJson]("parse_json"),
+
// cast
expression[Cast]("cast"),
// Cast aliases (SPARK-16730)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 4540ecffe0d21..793dd373d6899 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.encoders
import scala.collection.Map
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType}
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
@@ -68,6 +68,7 @@ object EncoderUtils {
case CalendarIntervalEncoder => true
case BinaryEncoder => true
case _: SparkDecimalEncoder => true
+ case VariantEncoder => true
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
index 6aa5fefc73902..50408b41c1a76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala
@@ -162,6 +162,8 @@ object InterpretedUnsafeProjection {
case PhysicalStringType => (v, i) => writer.write(i, v.getUTF8String(i))
+ case PhysicalVariantType => (v, i) => writer.write(i, v.getVariant(i))
+
case PhysicalStructType(fields) =>
val numFields = fields.length
val rowWriter = new UnsafeRowWriter(writer, numFields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
index 86871223d66ad..345f2b3030b58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
/**
* A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
@@ -120,6 +120,9 @@ class JoinedRow extends InternalRow {
override def getInterval(i: Int): CalendarInterval =
if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields)
+ override def getVariant(i: Int): VariantVal =
+ if (i < row1.numFields) row1.getVariant(i) else row2.getVariant(i - row1.numFields)
+
override def getMap(i: Int): MapData =
if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3595e43fcb987..4c32f682c275f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -1475,6 +1475,7 @@ object CodeGenerator extends Logging {
classOf[UTF8String].getName,
classOf[Decimal].getName,
classOf[CalendarInterval].getName,
+ classOf[VariantVal].getName,
classOf[ArrayData].getName,
classOf[UnsafeArrayData].getName,
classOf[MapData].getName,
@@ -1641,6 +1642,7 @@ object CodeGenerator extends Logging {
case PhysicalNullType => "null"
case PhysicalStringType => s"$input.getUTF8String($ordinal)"
case t: PhysicalStructType => s"$input.getStruct($ordinal, ${t.fields.size})"
+ case PhysicalVariantType => s"$input.getVariant($ordinal)"
case _ => s"($jt)$input.get($ordinal, null)"
}
}
@@ -1928,6 +1930,7 @@ object CodeGenerator extends Logging {
case PhysicalShortType => JAVA_SHORT
case PhysicalStringType => "UTF8String"
case _: PhysicalStructType => "InternalRow"
+ case _: PhysicalVariantType => "VariantVal"
case _ => "Object"
}
}
@@ -1951,6 +1954,7 @@ object CodeGenerator extends Logging {
case _: MapType => classOf[MapData]
case udt: UserDefinedType[_] => javaClass(udt.sqlType)
case ObjectType(cls) => cls
+ case VariantType => classOf[VariantVal]
case _ => classOf[Object]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 217ed562db779..c406ba0707b3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -97,6 +97,7 @@ object Literal {
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
Literal(convert(a), dataType)
case i: CalendarInterval => Literal(i, CalendarIntervalType)
+ case v: VariantVal => Literal(v, VariantType)
case null => Literal(null, NullType)
case v: Literal => v
case _ =>
@@ -143,6 +144,7 @@ object Literal {
case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT
case _ if clz == classOf[CalendarInterval] => CalendarIntervalType
+ case _ if clz == classOf[VariantVal] => VariantType
case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType))
@@ -235,6 +237,7 @@ object Literal {
case PhysicalNullType => true
case PhysicalShortType => v.isInstanceOf[Short]
case PhysicalStringType => v.isInstanceOf[UTF8String]
+ case PhysicalVariantType => v.isInstanceOf[VariantVal]
case st: PhysicalStructType =>
v.isInstanceOf[InternalRow] && {
val row = v.asInstanceOf[InternalRow]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 296d093a13de6..8379069c53d9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -47,6 +47,7 @@ trait BaseGenericInternalRow extends InternalRow {
override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getVariant(ordinal: Int): VariantVal = getAs(ordinal)
override def getMap(ordinal: Int): MapData = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
new file mode 100644
index 0000000000000..136ae4a3ef436
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.variant
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types._
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Throw an exception when the string is not valid JSON value.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('{"a":1,"b":0.8}');
+ {"a":1,"b":0.8}
+ """,
+ since = "4.0.0",
+ group = "variant_funcs"
+)
+// scalastyle:on line.size.limit
+case class ParseJson(child: Expression) extends UnaryExpression
+ with NullIntolerant with ExpectsInputTypes with CodegenFallback {
+ override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+
+ override def dataType: DataType = VariantType
+
+ override def prettyName: String = "parse_json"
+
+ protected override def nullSafeEval(input: Any): Any = {
+ // A dummy implementation: the value is the raw bytes of the input string. This is not the final
+ // implementation, but only intended for debugging.
+ // TODO(SPARK-45891): Have an actual parse_json implementation.
+ new VariantVal(input.asInstanceOf[UTF8String].toString.getBytes, Array())
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): ParseJson =
+ copy(child = newChild)
+}
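A usage sketch, assuming a local SparkSession built with this patch applied; the master setting and object name are assumptions. With the FunctionRegistry entry added above, parse_json is callable from SQL, and with the dummy evaluation here the result simply echoes the input text through VariantVal.toString.

import org.apache.spark.sql.SparkSession

object ParseJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("parse_json-sketch").getOrCreate()
    val df = spark.sql("""SELECT parse_json('{"a":1,"b":0.8}') AS v""")
    df.printSchema()              // v: variant (nullable = true)
    println(df.head().get(0))     // a VariantVal; its toString currently returns the raw JSON text
    spark.stop()
  }
}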
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
index 29d7a39ace3c1..290a35eb8e3b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
@@ -23,8 +23,8 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder}
import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType}
-import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal}
import org.apache.spark.util.ArrayImplicits._
sealed abstract class PhysicalDataType {
@@ -58,6 +58,7 @@ object PhysicalDataType {
case StructType(fields) => PhysicalStructType(fields)
case MapType(keyType, valueType, valueContainsNull) =>
PhysicalMapType(keyType, valueType, valueContainsNull)
+ case VariantType => PhysicalVariantType
case _ => UninitializedPhysicalType
}
@@ -327,6 +328,18 @@ case class PhysicalStructType(fields: Array[StructField]) extends PhysicalDataTy
}
}
+class PhysicalVariantType extends PhysicalDataType {
+ private[sql] type InternalType = VariantVal
+ @transient private[sql] lazy val tag = typeTag[InternalType]
+
+ // TODO(SPARK-45891): Support comparison for the Variant type.
+ override private[sql] def ordering =
+ throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
+ "PhysicalVariantType")
+}
+
+object PhysicalVariantType extends PhysicalVariantType
+
object UninitializedPhysicalType extends PhysicalDataType {
override private[sql] def ordering =
throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index bdf8d36321e64..7ff36bef5a4b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, Decimal}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
class GenericArrayData(val array: Array[Any]) extends ArrayData {
@@ -73,6 +73,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
+ override def getVariant(ordinal: Int): VariantVal = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
override def getMap(ordinal: Int): MapData = getAs(ordinal)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index c3249a4c02d8c..9cc99e9bfa335 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1564,6 +1564,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map.empty)
}
+ def cannotSaveVariantIntoExternalStorageError(): Throwable = {
+ new AnalysisException(
+ errorClass = "CANNOT_SAVE_VARIANT",
+ messageParameters = Map.empty)
+ }
+
def cannotResolveAttributeError(name: String, outputStr: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1137",
diff --git a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt
index 921491a4a4761..47a3f02ac1656 100644
--- a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt
+++ b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt
@@ -355,6 +355,7 @@ VAR_POP
VAR_SAMP
VARBINARY
VARCHAR
+VARIANT
VARYING
VERSIONING
WHEN
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
index 01aa3579aea98..eeb05139a3e5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
class GenerateUnsafeProjectionSuite extends SparkFunSuite {
test("Test unsafe projection string access pattern") {
@@ -88,6 +88,7 @@ object AlwaysNull extends InternalRow {
override def getUTF8String(ordinal: Int): UTF8String = notSupported
override def getBinary(ordinal: Int): Array[Byte] = notSupported
override def getInterval(ordinal: Int): CalendarInterval = notSupported
+ override def getVariant(ordinal: Int): VariantVal = notSupported
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
override def getArray(ordinal: Int): ArrayData = notSupported
override def getMap(ordinal: Int): MapData = notSupported
@@ -117,6 +118,7 @@ object AlwaysNonNull extends InternalRow {
override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test")
override def getBinary(ordinal: Int): Array[Byte] = notSupported
override def getInterval(ordinal: Int): CalendarInterval = notSupported
+ override def getVariant(ordinal: Int): VariantVal = notSupported
override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported
override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3"))
val keyArray = stringToUTF8Array(Array("1", "2", "3"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index eaed279679251..e2a416b773aa9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
class UnsafeRowWriterSuite extends SparkFunSuite {
@@ -61,4 +61,15 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
rowWriter.write(1, interval)
assert(rowWriter.getRow.getInterval(1) === interval)
}
+
+ test("write and get variant through UnsafeRowWriter") {
+ val rowWriter = new UnsafeRowWriter(2)
+ rowWriter.resetRowWriter()
+ rowWriter.setNullAt(0)
+ assert(rowWriter.getRow.isNullAt(0))
+ assert(rowWriter.getRow.getVariant(0) === null)
+ val variant = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](-1, -2, -3, -4))
+ rowWriter.write(1, variant)
+ assert(rowWriter.getRow.getVariant(1).debugString() == variant.debugString())
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 7b841ab9933e2..29c106651acf0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -38,6 +38,7 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* Utilities to help manipulate data associated with ColumnVectors. These should be used mostly
@@ -89,6 +90,8 @@ public static void populate(ConstantColumnVector col, InternalRow row, int field
} else if (pdt instanceof PhysicalCalendarIntervalType) {
// The value of `numRows` is irrelevant.
col.setCalendarInterval((CalendarInterval) row.get(fieldIdx, t));
+ } else if (pdt instanceof PhysicalVariantType) {
+ col.setVariant((VariantVal)row.get(fieldIdx, t));
} else {
throw new RuntimeException(String.format("DataType %s is not supported" +
" in column vectorized reader.", t.sql()));
@@ -124,7 +127,7 @@ public static Map toJavaIntMap(ColumnarMap map) {
private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
if (o == null) {
- if (t instanceof CalendarIntervalType) {
+ if (t instanceof CalendarIntervalType || t instanceof VariantType) {
dst.appendStruct(true);
} else {
dst.appendNull();
@@ -167,6 +170,11 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o)
dst.getChild(0).appendInt(c.months);
dst.getChild(1).appendInt(c.days);
dst.getChild(2).appendLong(c.microseconds);
+ } else if (t instanceof VariantType) {
+ VariantVal v = (VariantVal) o;
+ dst.appendStruct(false);
+ dst.getChild(0).appendByteArray(v.getValue(), 0, v.getValue().length);
+ dst.getChild(1).appendByteArray(v.getMetadata(), 0, v.getMetadata().length);
} else if (t instanceof DateType) {
dst.appendInt(DateTimeUtils.fromJavaDate((Date) o));
} else if (t instanceof TimestampType) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java
index 5095e6b0c9c6b..43854c2300fde 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java
@@ -25,6 +25,7 @@
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* This class adds the constant support to ColumnVector.
@@ -67,6 +68,10 @@ public ConstantColumnVector(int numRows, DataType type) {
this.childData[0] = new ConstantColumnVector(1, DataTypes.IntegerType);
this.childData[1] = new ConstantColumnVector(1, DataTypes.IntegerType);
this.childData[2] = new ConstantColumnVector(1, DataTypes.LongType);
+ } else if (type instanceof VariantType) {
+ this.childData = new ConstantColumnVector[2];
+ this.childData[0] = new ConstantColumnVector(1, DataTypes.BinaryType);
+ this.childData[1] = new ConstantColumnVector(1, DataTypes.BinaryType);
} else {
this.childData = null;
}
@@ -307,4 +312,12 @@ public void setCalendarInterval(CalendarInterval value) {
this.childData[1].setInt(value.days);
this.childData[2].setLong(value.microseconds);
}
+
+ /**
+ * Sets the Variant `value` for all rows
+ */
+ public void setVariant(VariantVal value) {
+ this.childData[0].setBinary(value.getValue());
+ this.childData[1].setBinary(value.getMetadata());
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index eda58815f3b3a..0a110a204e04b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -28,6 +28,7 @@
import org.apache.spark.sql.vectorized.ColumnarRow;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
/**
* A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash
@@ -142,6 +143,11 @@ public CalendarInterval getInterval(int ordinal) {
return columns[ordinal].getInterval(rowId);
}
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return columns[ordinal].getVariant(rowId);
+ }
+
@Override
public ColumnarRow getStruct(int ordinal, int numFields) {
return columns[ordinal].getStruct(rowId);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 4c8ceff356595..10907c69c2260 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -694,7 +694,7 @@ public final int appendStruct(boolean isNull) {
putNull(elementsAppended);
elementsAppended++;
for (WritableColumnVector c: childColumns) {
- if (c.type instanceof StructType) {
+ if (c.type instanceof StructType || c.type instanceof VariantType) {
c.appendStruct(true);
} else {
c.appendNull();
@@ -975,6 +975,10 @@ protected WritableColumnVector(int capacity, DataType dataType) {
this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
this.childColumns[1] = reserveNewColumn(capacity, DataTypes.IntegerType);
this.childColumns[2] = reserveNewColumn(capacity, DataTypes.LongType);
+ } else if (type instanceof VariantType) {
+ this.childColumns = new WritableColumnVector[2];
+ this.childColumns[0] = reserveNewColumn(capacity, DataTypes.BinaryType);
+ this.childColumns[1] = reserveNewColumn(capacity, DataTypes.BinaryType);
} else {
this.childColumns = null;
}
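A sketch of the columnar representation, assuming OnHeapColumnVector from the same module; the capacity and object name are assumptions. It mirrors appendValue in ColumnVectorUtils above: a variant vector carries two binary child vectors, child 0 for the value and child 1 for the metadata, and ColumnVector.getVariant reassembles them.

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.types.VariantType
import org.apache.spark.unsafe.types.VariantVal

object VariantColumnSketch {
  def main(args: Array[String]): Unit = {
    val vector = new OnHeapColumnVector(4, VariantType)
    val v = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](-1, -2))
    // Mark the row as non-null, then append the two binary pieces to the child vectors.
    vector.appendStruct(false)
    vector.getChild(0).appendByteArray(v.getValue, 0, v.getValue.length)
    vector.getChild(1).appendByteArray(v.getMetadata, 0, v.getMetadata.length)
    println(vector.getVariant(0).debugString())   // VariantVal{value=[1, 2, 3], metadata=[-1, -2]}
    vector.close()
  }
}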
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 3fec13a7f9ba9..7c117e0cace97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -267,6 +267,7 @@ private object RowToColumnConverter {
case DoubleType => DoubleConverter
case StringType => StringConverter
case CalendarIntervalType => CalendarConverter
+ case VariantType => VariantConverter
case at: ArrayType => ArrayConverter(getConverterForType(at.elementType, at.containsNull))
case st: StructType => new StructConverter(st.fields.map(
(f) => getConverterForType(f.dataType, f.nullable)))
@@ -346,6 +347,15 @@ private object RowToColumnConverter {
}
}
+ private object VariantConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val v = row.getVariant(column)
+ cv.appendStruct(false)
+ cv.getChild(0).appendByteArray(v.getValue, 0, v.getValue.length)
+ cv.getChild(1).appendByteArray(v.getMetadata, 0, v.getMetadata.length)
+ }
+ }
+
private case class ArrayConverter(childConverter: TypeConverter) extends TypeConverter {
override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
val values = row.getArray(column)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
index 9811a1d3f33e4..f6b5ba15afbd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedComm
import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -131,6 +131,7 @@ object HiveResult {
HIVE_STYLE,
startField,
endField)
+ case (v: VariantVal, VariantType) => v.toString
case (other, _: UserDefinedType[_]) => other.toString
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index cd295f3b17bd6..835308f3d0248 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, Tex
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, StructField, StructType, VariantType}
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
@@ -503,6 +503,7 @@ case class DataSource(
providingInstance() match {
case dataSource: CreatableRelationProvider =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = true)
+ disallowWritingVariant(outputColumns.map(_.dataType))
dataSource.createRelation(
sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
case format: FileFormat =>
@@ -524,6 +525,7 @@ case class DataSource(
providingInstance() match {
case dataSource: CreatableRelationProvider =>
disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = true)
+ disallowWritingVariant(data.schema.map(_.dataType))
SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
case format: FileFormat =>
disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false)
@@ -560,6 +562,14 @@ case class DataSource(
throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError()
})
}
+
+ private def disallowWritingVariant(dataTypes: Seq[DataType]): Unit = {
+ dataTypes.foreach { dt =>
+ if (dt.existsRecursively(_.isInstanceOf[VariantType])) {
+ throw QueryCompilationErrors.cannotSaveVariantIntoExternalStorageError()
+ }
+ }
+ }
}
object DataSource extends Logging {
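
Note: disallowWritingVariant guards only the CreatableRelationProvider path (e.g. JDBC); file sources are instead gated by their supportDataType overrides further down. A standalone sketch of the same check, using a hypothetical containsVariant helper rather than the internal existsRecursively:

    import org.apache.spark.sql.types._

    // Hypothetical helper mirroring disallowWritingVariant: walk the type tree and
    // report whether a VariantType appears anywhere, including nested occurrences.
    def containsVariant(dt: DataType): Boolean = dt match {
      case _: VariantType => true
      case s: StructType => s.fields.exists(f => containsVariant(f.dataType))
      case a: ArrayType => containsVariant(a.elementType)
      case m: MapType => containsVariant(m.keyType) || containsVariant(m.valueType)
      case _ => false
    }

    containsVariant(StructType.fromDDL("s struct<a: int, v: variant>"))  // true
    containsVariant(IntegerType)                                         // false
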
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 069ad9562a7d5..32370562003f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -145,6 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat]
override def supportDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: BinaryType => false
case _: AtomicType => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 9c6c77a8b9622..7fb6e98fb0468 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -134,6 +134,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat]
override def supportDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index b7e6f11f67d69..623f97499cd55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -237,6 +237,8 @@ class OrcFileFormat
}
override def supportDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index eedd165278aed..f60f7c11eefa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -185,6 +185,11 @@ class ParquetToSparkSchemaConverter(
}
field match {
case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)
+ case groupColumn: GroupColumnIO if targetType.contains(VariantType) =>
+ ParquetColumn(VariantType, groupColumn, Seq(
+ convertField(groupColumn.getChild(0), Some(BinaryType)),
+ convertField(groupColumn.getChild(1), Some(BinaryType))
+ ))
case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
}
}
@@ -719,6 +724,12 @@ class SparkToParquetSchemaConverter(
// Other types
// ===========
+ case VariantType =>
+ Types.buildGroup(repetition)
+ .addField(convertField(StructField("value", BinaryType, nullable = false)))
+ .addField(convertField(StructField("metadata", BinaryType, nullable = false)))
+ .named(field.name)
+
case StructType(fields) =>
fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) =>
builder.addField(convertField(field))
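
Note: with both directions in place, a Catalyst variant column maps to a Parquet group with two required binary fields, value and metadata, and is mapped back to VariantType on read when the requested Spark type asks for it. A sketch that prints the write-side mapping, assuming the converter's default constructor arguments:

    import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
    import org.apache.spark.sql.types.StructType

    val converter = new SparkToParquetSchemaConverter()
    val parquetSchema = converter.convert(StructType.fromDDL("v variant"))
    println(parquetSchema)
    // Expected shape, roughly (the group's repetition follows the column's nullability):
    //   message spark_schema {
    //     optional group v {
    //       required binary value;
    //       required binary metadata;
    //     }
    //   }
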
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 9535bbd585bce..e410789504e70 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -238,6 +238,18 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
case DecimalType.Fixed(precision, scale) =>
makeDecimalWriter(precision, scale)
+ case VariantType =>
+ (row: SpecializedGetters, ordinal: Int) =>
+ val v = row.getVariant(ordinal)
+ consumeGroup {
+ consumeField("value", 0) {
+ recordConsumer.addBinary(Binary.fromReusedByteArray(v.getValue))
+ }
+ consumeField("metadata", 1) {
+ recordConsumer.addBinary(Binary.fromReusedByteArray(v.getMetadata))
+ }
+ }
+
case t: StructType =>
val fieldWriters = t.map(_.dataType).map(makeWriter).toArray[ValueWriter]
(row: SpecializedGetters, ordinal: Int) =>
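
Note: together with the schema mapping above, this writer lets a variant column round-trip through Parquet end to end. A minimal sketch against an active SparkSession `spark`, mirroring the new VariantSuite below (the output path is a placeholder):

    val df = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)")
    df.write.mode("overwrite").parquet("/tmp/variant_roundtrip")  // placeholder path
    val back = spark.read.parquet("/tmp/variant_roundtrip")
    // Rows come back as VariantVal; VariantSuite compares them via toString.
    back.collect().foreach(row => println(row.get(0)))
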
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index 776192992789a..300c0f5004252 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -140,6 +140,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat]
override def supportDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: AtomicType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
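
Note: these supportDataType additions make the CSV, JSON, ORC and XML sources reject variant columns up front, while Parquet accepts them through the changes above. A hedged sketch of the user-visible effect (the exact error class and wording are approximate):

    // Assumes an active SparkSession `spark` with this change applied.
    val df = spark.sql("""select parse_json('{"a":1}') as v""")
    df.write.mode("overwrite").parquet("/tmp/variant_ok")  // placeholder path; succeeds
    // df.write.mode("overwrite").csv("/tmp/variant_csv")  // fails: the CSV source
    //   reports an unsupported data type for column `v` (AnalysisException).
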
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 017cc474ea028..8e6bad11c09a9 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -429,6 +429,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> |
+| org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_double | SELECT xpath_double('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_double(<a><b>1</b><b>2</b></a>, sum(a/b)):double> |
| org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_number | SELECT xpath_number('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_number(<a><b>1</b><b>2</b></a>, sum(a/b)):double> |
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index f88dcbd465852..10fcee1469398 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -320,6 +320,7 @@ VALUES false
VAR false
VARCHAR false
VARIABLE false
+VARIANT false
VERSION false
VIEW false
VIEWS false
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index b618299ea61a8..be2303a716da5 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -320,6 +320,7 @@ VALUES false
VAR false
VARCHAR false
VARIABLE false
+VARIANT false
VERSION false
VIEW false
VIEWS false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
new file mode 100644
index 0000000000000..dde986c555b10
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.File
+
+import scala.util.Random
+
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.VariantVal
+
+class VariantSuite extends QueryTest with SharedSparkSession {
+ test("basic tests") {
+ def verifyResult(df: DataFrame): Unit = {
+ val result = df.collect()
+ .map(_.get(0).asInstanceOf[VariantVal].toString)
+ .sorted
+ .toSeq
+ val expected = (1 until 10).map(id => "1" * id)
+ assert(result == expected)
+ }
+
+ // At this point, JSON parsing logic is not really implemented. We just construct some number
+ // inputs that are also valid JSON. This exercises passing VariantVal throughout the system.
+ val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)")
+ verifyResult(query)
+
+ // Write into and read from Parquet.
+ withTempDir { dir =>
+ val tempDir = new File(dir, "files").getCanonicalPath
+ query.write.parquet(tempDir)
+ verifyResult(spark.read.parquet(tempDir))
+ }
+ }
+
+ test("round trip tests") {
+ val rand = new Random(42)
+ val input = Seq.fill(50) {
+ if (rand.nextInt(10) == 0) {
+ null
+ } else {
+ val value = new Array[Byte](rand.nextInt(50))
+ rand.nextBytes(value)
+ val metadata = new Array[Byte](rand.nextInt(50))
+ rand.nextBytes(metadata)
+ new VariantVal(value, metadata)
+ }
+ }
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(input.map(Row(_))),
+ StructType.fromDDL("v variant")
+ )
+ val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
+
+ def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
+ values.map(v => if (v == null) "null" else v.debugString()).sorted
+ }
+ assert(prepareAnswer(input) == prepareAnswer(result))
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 72e6fae92cbc2..9bb35bb8719ea 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE")
+ assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE")
// scalastyle:on line.size.limit
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 3cf6fcbc65ace..5ccd40aefa255 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -191,6 +191,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
}
override def supportDataType(dataType: DataType): Boolean = dataType match {
+ case _: VariantType => false
+
case _: AnsiIntervalType => false
case _: AtomicType => true