diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java new file mode 100644 index 0000000000000..40bd1c7abc75f --- /dev/null +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/VariantVal.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.types; + +import org.apache.spark.unsafe.Platform; + +import java.io.Serializable; +import java.util.Arrays; + +/** + * The physical data representation of {@link org.apache.spark.sql.types.VariantType} that + * represents a semi-structured value. It consists of two binary values: {@link VariantVal#value} + * and {@link VariantVal#metadata}. The value encodes types and values, but not field names. The + * metadata currently contains a version flag and a list of field names. We can extend/modify the + * detailed binary format given the version flag. + *
+ * A {@link VariantVal} can be produced by casting another value into the Variant type or parsing a + * JSON string in the {@link org.apache.spark.sql.catalyst.expressions.variant.ParseJson} + * expression. We can extract a path consisting of field names and array indices from it, cast it + * into a concrete data type, or rebuild a JSON string from it. + *
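+ * Note: in this patch the actual Variant encoding is still a placeholder. {@code parse_json}
+ * simply stores the raw bytes of the input JSON string in {@link VariantVal#value} with an empty
+ * {@link VariantVal#metadata}, and {@link VariantVal#toString()} returns that string unchanged
+ * (see SPARK-45891 for the follow-up work).
+ *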
+ * The storage layout of this class in {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} + * and {@link org.apache.spark.sql.catalyst.expressions.UnsafeArrayData} is: the fixed-size part is + * a long value "offsetAndSize". The upper 32 bits is the offset that points to the start position + * of the actual binary content. The lower 32 bits is the total length of the binary content. The + * binary content contains: 4 bytes representing the length of {@link VariantVal#value}, content of + * {@link VariantVal#value}, content of {@link VariantVal#metadata}. This is an internal and + * transient format and can be modified at any time. + */ +public class VariantVal implements Serializable { + protected final byte[] value; + protected final byte[] metadata; + + public VariantVal(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + + public byte[] getValue() { + return value; + } + + public byte[] getMetadata() { + return metadata; + } + + /** + * This function reads the binary content described in `writeIntoUnsafeRow` from `baseObject`. The + * offset is computed by adding the offset in {@code offsetAndSize} and {@code baseOffset}. + */ + public static VariantVal readFromUnsafeRow( + long offsetAndSize, + Object baseObject, + long baseOffset) { + // offset and totalSize is the upper/lower 32 bits in offsetAndSize. + int offset = (int) (offsetAndSize >> 32); + int totalSize = (int) offsetAndSize; + int valueSize = Platform.getInt(baseObject, baseOffset + offset); + int metadataSize = totalSize - 4 - valueSize; + byte[] value = new byte[valueSize]; + byte[] metadata = new byte[metadataSize]; + Platform.copyMemory( + baseObject, + baseOffset + offset + 4, + value, + Platform.BYTE_ARRAY_OFFSET, + valueSize + ); + Platform.copyMemory( + baseObject, + baseOffset + offset + 4 + valueSize, + metadata, + Platform.BYTE_ARRAY_OFFSET, + metadataSize + ); + return new VariantVal(value, metadata); + } + + public String debugString() { + return "VariantVal{" + + "value=" + Arrays.toString(value) + + ", metadata=" + Arrays.toString(metadata) + + '}'; + } + + /** + * @return A human-readable representation of the Variant value. It is always a JSON string at + * this moment. + */ + @Override + public String toString() { + // NOTE: the encoding is not yet implemented, this is not the final implementation. + return new String(value); + } +} diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index e3b9f3161b24d..1b4c10acaf7bc 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -317,6 +317,12 @@ ], "sqlState" : "58030" }, + "CANNOT_SAVE_VARIANT" : { + "message" : [ + "Cannot save variant data type into external storage." 
+ ], + "sqlState" : "0A000" + }, "CANNOT_UPDATE_FIELD" : { "message" : [ "Cannot update field type:" diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 0e27e4a604c46..e235c13d413e2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -72,6 +72,8 @@ private[sql] object AvroUtils extends Logging { } def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportsDataType(f.dataType) } diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index a811019e0a57b..ee9c2fd67b307 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -280,6 +280,12 @@ SQLSTATE: 58030 Failed to set permissions on created path `` back to ``. +### CANNOT_SAVE_VARIANT + +[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) + +Cannot save variant data type into external storage. + ### [CANNOT_UPDATE_FIELD](sql-error-conditions-cannot-update-field-error-class.html) [SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 09c38a0099599..4729db16d63f3 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -671,6 +671,7 @@ Below is a list of all the keywords in Spark SQL. |VARCHAR|non-reserved|non-reserved|reserved| |VAR|non-reserved|non-reserved|non-reserved| |VARIABLE|non-reserved|non-reserved|non-reserved| +|VARIANT|non-reserved|non-reserved|reserved| |VERSION|non-reserved|non-reserved|non-reserved| |VIEW|non-reserved|non-reserved|non-reserved| |VIEWS|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index e8b5cb012fcae..9b3dcbc6d194f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -408,6 +408,7 @@ VALUES: 'VALUES'; VARCHAR: 'VARCHAR'; VAR: 'VAR'; VARIABLE: 'VARIABLE'; +VARIANT: 'VARIANT'; VERSION: 'VERSION'; VIEW: 'VIEW'; VIEWS: 'VIEWS'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index bd449a4e194e8..609bd72e21935 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1086,6 +1086,7 @@ type | DECIMAL | DEC | NUMERIC | VOID | INTERVAL + | VARIANT | ARRAY | STRUCT | MAP | unsupportedType=identifier ; @@ -1545,6 +1546,7 @@ ansiNonReserved | VARCHAR | VAR | VARIABLE + | VARIANT | VERSION | VIEW | VIEWS @@ -1893,6 +1895,7 @@ nonReserved | VARCHAR | VAR | VARIABLE + | VARIANT | VERSION | VIEW | VIEWS diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala index e5e9ba644b814..9133abce88adc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala +++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala @@ -25,7 +25,7 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.SparkClassUtils /** @@ -216,6 +216,7 @@ object AgnosticEncoders { case object CalendarIntervalEncoder extends LeafEncoder[CalendarInterval](CalendarIntervalType) case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType()) case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType()) + case object VariantEncoder extends LeafEncoder[VariantVal](VariantType) case class DateEncoder(override val lenientSerialization: Boolean) extends LeafEncoder[jsql.Date](DateType) case class LocalDateEncoder(override val lenientSerialization: Boolean) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 69661c343c5b1..a201da9c95c9e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.reflect.classTag import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VariantEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ @@ -90,6 +90,7 @@ object RowEncoder { case CalendarIntervalType => CalendarIntervalEncoder case _: DayTimeIntervalType => DayTimeIntervalEncoder case _: YearMonthIntervalType => YearMonthIntervalEncoder + case _: VariantType => VariantEncoder case p: PythonUserDefinedType => // TODO check if this works. 
encoderForDataType(p.sqlType, lenient) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index b30c6fa29e829..3a2e704ffe9f7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { protected def typedVisit[T](ctx: ParseTree): T = { @@ -82,6 +82,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { DecimalType(precision.getText.toInt, scale.getText.toInt) case (VOID, Nil) => NullType case (INTERVAL, Nil) => CalendarIntervalType + case (VARIANT, Nil) => VariantType case (CHARACTER | CHAR | VARCHAR, Nil) => throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx) case (ARRAY | STRUCT | MAP, Nil) => diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala index 5f563e3b7a8f1..94252de48d1ea 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -173,7 +173,8 @@ object DataType { YearMonthIntervalType(YEAR), YearMonthIntervalType(MONTH), YearMonthIntervalType(YEAR, MONTH), - TimestampNTZType) + TimestampNTZType, + VariantType) .map(t => t.typeName -> t).toMap } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala new file mode 100644 index 0000000000000..103fe7a59fc83 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/VariantType.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.annotation.Unstable + +/** + * The data type representing semi-structured values with arbitrary hierarchical data structures. It + * is intended to store parsed JSON values and most other data types in the system (e.g., it cannot + * store a map with a non-string key type). + * + * @since 4.0.0 + */ +@Unstable +class VariantType private () extends AtomicType { + // The default size is used in query planning to drive optimization decisions. 2048 is arbitrarily + // picked and we currently don't have any data to support it. This may need revisiting later. + override def defaultSize: Int = 2048 + + /** This is a no-op because values with VARIANT type are always nullable. */ + private[spark] override def asNullable: VariantType = this +} + +/** + * @since 4.0.0 + */ +@Unstable +case object VariantType extends VariantType diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java index e7af5a68b4663..ffc3c8eaf8f84 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -45,7 +45,7 @@ public class ExpressionInfo { "collection_funcs", "predicate_funcs", "conditional_funcs", "conversion_funcs", "csv_funcs", "datetime_funcs", "generator_funcs", "hash_funcs", "json_funcs", "lambda_funcs", "map_funcs", "math_funcs", "misc_funcs", "string_funcs", "struct_funcs", - "window_funcs", "xml_funcs", "table_funcs", "url_funcs")); + "window_funcs", "xml_funcs", "table_funcs", "url_funcs", "variant_funcs")); private static final Set validSources = new HashSet<>(Arrays.asList("built-in", "hive", "python_udf", "scala_udf", diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index eea7149d02594..b88a892db4b46 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; public interface SpecializedGetters { @@ -51,6 +52,8 @@ public interface SpecializedGetters { CalendarInterval getInterval(int ordinal); + VariantVal getVariant(int ordinal); + InternalRow getStruct(int ordinal, int numFields); ArrayData getArray(int ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java index 91f04c3d327ac..9e508dbb271cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGettersReader.java @@ -66,6 +66,9 @@ public static Object read( if (physicalDataType instanceof PhysicalBinaryType) { return obj.getBinary(ordinal); } + if (physicalDataType instanceof 
PhysicalVariantType) { + return obj.getVariant(ordinal); + } if (physicalDataType instanceof PhysicalStructType) { return obj.getStruct(ordinal, ((PhysicalStructType) physicalDataType).fields().length); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index ea6f1e05422b5..700e42cb843c8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -38,6 +38,7 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -231,6 +232,12 @@ public CalendarInterval getInterval(int ordinal) { return new CalendarInterval(months, days, microseconds); } + @Override + public VariantVal getVariant(int ordinal) { + if (isNullAt(ordinal)) return null; + return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset); + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) return null; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 8f9d5919e1d9f..fca45c58beed0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -36,6 +36,7 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; @@ -417,6 +418,12 @@ public CalendarInterval getInterval(int ordinal) { } } + @Override + public VariantVal getVariant(int ordinal) { + if (isNullAt(ordinal)) return null; + return VariantVal.readFromUnsafeRow(getLong(ordinal), baseObject, baseOffset); + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 8d4e187d01a12..d651e5ab5b3e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -25,6 +25,7 @@ import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Base class for writing Unsafe* structures. @@ -149,6 +150,27 @@ public void write(int ordinal, CalendarInterval input) { increaseCursor(16); } + public void write(int ordinal, VariantVal input) { + // See the class comment of VariantVal for the format of the binary content. 
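+ // Concretely: a 4-byte int holding value.length, followed by the value bytes and then the
+ // metadata bytes. The buffer is grown (and zero-padded) to a whole number of 8-byte words,
+ // and the fixed-size slot for this ordinal stores (offset << 32) | totalSize, which is what
+ // VariantVal.readFromUnsafeRow expects when it decodes the field again.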
+ byte[] value = input.getValue(); + byte[] metadata = input.getMetadata(); + int totalSize = 4 + value.length + metadata.length; + int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSize); + grow(roundedSize); + zeroOutPaddingBytes(totalSize); + Platform.putInt(getBuffer(), cursor(), value.length); + Platform.copyMemory(value, Platform.BYTE_ARRAY_OFFSET, getBuffer(), cursor() + 4, value.length); + Platform.copyMemory( + metadata, + Platform.BYTE_ARRAY_OFFSET, + getBuffer(), + cursor() + 4 + value.length, + metadata.length + ); + setOffsetAndSize(ordinal, totalSize); + increaseCursor(roundedSize); + } + public final void write(int ordinal, UnsafeRow row) { writeAlignedBytes(ordinal, row.getBaseObject(), row.getBaseOffset(), row.getSizeInBytes()); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 73c2cf2cc05f8..cd3c30fa69335 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.types.UserDefinedType; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * An interface representing in-memory columnar data in Spark. This interface defines the main APIs @@ -299,6 +300,16 @@ public CalendarInterval getInterval(int rowId) { return new CalendarInterval(months, days, microseconds); } + /** + * Returns the Variant value for {@code rowId}. Similar to {@link #getInterval(int)}, the + * implementation must implement {@link #getChild(int)} and define 2 child vectors of binary type + * for the Variant value and metadata. + */ + public final VariantVal getVariant(int rowId) { + if (isNullAt(rowId)) return null; + return new VariantVal(getChild(0).getBinary(rowId), getChild(1).getBinary(rowId)); + } + /** * @return child {@link ColumnVector} at the given ordinal. */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index bd7c3d7c0fd49..e0141a575b299 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Array abstraction in {@link ColumnVector}. 
@@ -160,6 +161,11 @@ public CalendarInterval getInterval(int ordinal) { return data.getInterval(offset + ordinal); } + @Override + public VariantVal getVariant(int ordinal) { + return data.getVariant(offset + ordinal); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return data.getStruct(offset + ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java index c0d2ae8e7d0e8..ac23f70584e89 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * This class wraps an array of {@link ColumnVector} and provides a row view. @@ -133,6 +134,11 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return columns[ordinal].getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return columns[ordinal].getStruct(rowId); @@ -182,6 +188,8 @@ public Object get(int ordinal, DataType dataType) { return getStruct(ordinal, ((StructType)dataType).fields().length); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof VariantType) { + return getVariant(ordinal); } else { throw new UnsupportedOperationException("Datatype not supported " + dataType); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 1df4653f55276..18f6779cccb96 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Row abstraction in {@link ColumnVector}. 
@@ -140,6 +141,11 @@ public CalendarInterval getInterval(int ordinal) { return data.getChild(ordinal).getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return data.getChild(ordinal).getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return data.getChild(ordinal).getStruct(rowId); @@ -187,6 +193,8 @@ public Object get(int ordinal, DataType dataType) { return getStruct(ordinal, ((StructType)dataType).fields().length); } else if (dataType instanceof MapType) { return getMap(ordinal); + } else if (dataType instanceof VariantType) { + return getVariant(ordinal); } else { throw new UnsupportedOperationException("Datatype not supported " + dataType); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala index 429ce805bf2c4..034b959c5a383 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ProjectingInternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} /** * An [[InternalRow]] that projects particular columns from another [[InternalRow]] without copying @@ -99,6 +99,10 @@ case class ProjectingInternalRow(schema: StructType, colOrdinals: Seq[Int]) exte row.getInterval(colOrdinals(ordinal)) } + override def getVariant(ordinal: Int): VariantVal = { + row.getVariant(colOrdinals(ordinal)) + } + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { row.getStruct(colOrdinals(ordinal), numFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 23d63011db53f..4fb8d88f6eab1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.variant._ import org.apache.spark.sql.catalyst.expressions.xml._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range} import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -808,6 +809,9 @@ object FunctionRegistry { expression[LengthOfJsonArray]("json_array_length"), expression[JsonObjectKeys]("json_object_keys"), + // Variant + expression[ParseJson]("parse_json"), + // cast expression[Cast]("cast"), // Cast aliases (SPARK-16730) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala index 4540ecffe0d21..793dd373d6899 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.collection.Map import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, CalendarIntervalEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, SparkDecimalEncoder, VariantEncoder} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} @@ -68,6 +68,7 @@ object EncoderUtils { case CalendarIntervalEncoder => true case BinaryEncoder => true case _: SparkDecimalEncoder => true + case VariantEncoder => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala index 6aa5fefc73902..50408b41c1a76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedUnsafeProjection.scala @@ -162,6 +162,8 @@ object InterpretedUnsafeProjection { case PhysicalStringType => (v, i) => writer.write(i, v.getUTF8String(i)) + case PhysicalVariantType => (v, i) => writer.write(i, v.getVariant(i)) + case PhysicalStructType(fields) => val numFields = fields.length val rowWriter = new UnsafeRowWriter(writer, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 86871223d66ad..345f2b3030b58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} /** * A mutable wrapper that makes two rows appear as a single concatenated row. 
Designed to @@ -120,6 +120,9 @@ class JoinedRow extends InternalRow { override def getInterval(i: Int): CalendarInterval = if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + override def getVariant(i: Int): VariantVal = + if (i < row1.numFields) row1.getVariant(i) else row2.getVariant(i - row1.numFields) + override def getMap(i: Int): MapData = if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3595e43fcb987..4c32f682c275f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1475,6 +1475,7 @@ object CodeGenerator extends Logging { classOf[UTF8String].getName, classOf[Decimal].getName, classOf[CalendarInterval].getName, + classOf[VariantVal].getName, classOf[ArrayData].getName, classOf[UnsafeArrayData].getName, classOf[MapData].getName, @@ -1641,6 +1642,7 @@ object CodeGenerator extends Logging { case PhysicalNullType => "null" case PhysicalStringType => s"$input.getUTF8String($ordinal)" case t: PhysicalStructType => s"$input.getStruct($ordinal, ${t.fields.size})" + case PhysicalVariantType => s"$input.getVariant($ordinal)" case _ => s"($jt)$input.get($ordinal, null)" } } @@ -1928,6 +1930,7 @@ object CodeGenerator extends Logging { case PhysicalShortType => JAVA_SHORT case PhysicalStringType => "UTF8String" case _: PhysicalStructType => "InternalRow" + case _: PhysicalVariantType => "VariantVal" case _ => "Object" } } @@ -1951,6 +1954,7 @@ object CodeGenerator extends Logging { case _: MapType => classOf[MapData] case udt: UserDefinedType[_] => javaClass(udt.sqlType) case ObjectType(cls) => cls + case VariantType => classOf[VariantVal] case _ => classOf[Object] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 217ed562db779..c406ba0707b3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -97,6 +97,7 @@ object Literal { val convert = CatalystTypeConverters.createToCatalystConverter(dataType) Literal(convert(a), dataType) case i: CalendarInterval => Literal(i, CalendarIntervalType) + case v: VariantVal => Literal(v, VariantType) case null => Literal(null, NullType) case v: Literal => v case _ => @@ -143,6 +144,7 @@ object Literal { case _ if clz == classOf[BigInt] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[BigDecimal] => DecimalType.SYSTEM_DEFAULT case _ if clz == classOf[CalendarInterval] => CalendarIntervalType + case _ if clz == classOf[VariantVal] => VariantType case _ if clz.isArray => ArrayType(componentTypeToDataType(clz.getComponentType)) @@ -235,6 +237,7 @@ object Literal { case PhysicalNullType => true case PhysicalShortType => v.isInstanceOf[Short] case PhysicalStringType => v.isInstanceOf[UTF8String] + case PhysicalVariantType => v.isInstanceOf[VariantVal] case st: PhysicalStructType => v.isInstanceOf[InternalRow] && { val row = v.asInstanceOf[InternalRow] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 296d093a13de6..8379069c53d9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} import org.apache.spark.util.ArrayImplicits._ /** @@ -47,6 +47,7 @@ trait BaseGenericInternalRow extends InternalRow { override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala new file mode 100644 index 0000000000000..136ae4a3ef436 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.variant + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types._ + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr) - Parse a JSON string as an Variant value. Throw an exception when the string is not valid JSON value.", + examples = """ + Examples: + > SELECT _FUNC_('{"a":1,"b":0.8}'); + {"a":1,"b":0.8} + """, + since = "4.0.0", + group = "variant_funcs" +) +// scalastyle:on line.size.limit +case class ParseJson(child: Expression) extends UnaryExpression + with NullIntolerant with ExpectsInputTypes with CodegenFallback { + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def dataType: DataType = VariantType + + override def prettyName: String = "parse_json" + + protected override def nullSafeEval(input: Any): Any = { + // A dummy implementation: the value is the raw bytes of the input string. 
This is not the final + // implementation, but only intended for debugging. + // TODO(SPARK-45891): Have an actual parse_json implementation. + new VariantVal(input.asInstanceOf[UTF8String].toString.getBytes, Array()) + } + + override protected def withNewChildInternal(newChild: Expression): ParseJson = + copy(child = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 29d7a39ace3c1..290a35eb8e3b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -23,8 +23,8 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder} import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType} -import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType} +import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal} import org.apache.spark.util.ArrayImplicits._ sealed abstract class PhysicalDataType { @@ -58,6 +58,7 @@ object PhysicalDataType { case StructType(fields) => PhysicalStructType(fields) case MapType(keyType, valueType, valueContainsNull) => PhysicalMapType(keyType, valueType, valueContainsNull) + case VariantType => PhysicalVariantType case _ => UninitializedPhysicalType } @@ -327,6 +328,18 @@ case class PhysicalStructType(fields: Array[StructField]) extends PhysicalDataTy } } +class PhysicalVariantType extends PhysicalDataType { + private[sql] type InternalType = VariantVal + @transient private[sql] lazy val tag = typeTag[InternalType] + + // TODO(SPARK-45891): Support comparison for the Variant type. 
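+ // Until comparison is supported, any code path that requests an ordering over variant values
+ // ends up in the exception below.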
+ override private[sql] def ordering = + throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( + "PhysicalVariantType") +} + +object PhysicalVariantType extends PhysicalVariantType + object UninitializedPhysicalType extends PhysicalDataType { override private[sql] def ordering = throw QueryExecutionErrors.orderedOperationUnsupportedByDataTypeError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index bdf8d36321e64..7ff36bef5a4b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class GenericArrayData(val array: Array[Any]) extends ArrayData { @@ -73,6 +73,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + override def getVariant(ordinal: Int): VariantVal = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index c3249a4c02d8c..9cc99e9bfa335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1564,6 +1564,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map.empty) } + def cannotSaveVariantIntoExternalStorageError(): Throwable = { + new AnalysisException( + errorClass = "CANNOT_SAVE_VARIANT", + messageParameters = Map.empty) + } + def cannotResolveAttributeError(name: String, outputStr: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1137", diff --git a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt index 921491a4a4761..47a3f02ac1656 100644 --- a/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt +++ b/sql/catalyst/src/test/resources/ansi-sql-2016-reserved-keywords.txt @@ -355,6 +355,7 @@ VAR_POP VAR_SAMP VARBINARY VARCHAR +VARIANT VARYING VERSIONING WHEN diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index 01aa3579aea98..eeb05139a3e5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -22,7 +22,7 
@@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal} class GenerateUnsafeProjectionSuite extends SparkFunSuite { test("Test unsafe projection string access pattern") { @@ -88,6 +88,7 @@ object AlwaysNull extends InternalRow { override def getUTF8String(ordinal: Int): UTF8String = notSupported override def getBinary(ordinal: Int): Array[Byte] = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = notSupported override def getMap(ordinal: Int): MapData = notSupported @@ -117,6 +118,7 @@ object AlwaysNonNull extends InternalRow { override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromString("test") override def getBinary(ordinal: Int): Array[Byte] = notSupported override def getInterval(ordinal: Int): CalendarInterval = notSupported + override def getVariant(ordinal: Int): VariantVal = notSupported override def getStruct(ordinal: Int, numFields: Int): InternalRow = notSupported override def getArray(ordinal: Int): ArrayData = stringToUTF8Array(Array("1", "2", "3")) val keyArray = stringToUTF8Array(Array("1", "2", "3")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index eaed279679251..e2a416b773aa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types.Decimal -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} class UnsafeRowWriterSuite extends SparkFunSuite { @@ -61,4 +61,15 @@ class UnsafeRowWriterSuite extends SparkFunSuite { rowWriter.write(1, interval) assert(rowWriter.getRow.getInterval(1) === interval) } + + test("write and get variant through UnsafeRowWriter") { + val rowWriter = new UnsafeRowWriter(2) + rowWriter.resetRowWriter() + rowWriter.setNullAt(0) + assert(rowWriter.getRow.isNullAt(0)) + assert(rowWriter.getRow.getVariant(0) === null) + val variant = new VariantVal(Array[Byte](1, 2, 3), Array[Byte](-1, -2, -3, -4)) + rowWriter.write(1, variant) + assert(rowWriter.getRow.getVariant(1).debugString() == variant.debugString()) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 7b841ab9933e2..29c106651acf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -38,6 +38,7 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import 
org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly @@ -89,6 +90,8 @@ public static void populate(ConstantColumnVector col, InternalRow row, int field } else if (pdt instanceof PhysicalCalendarIntervalType) { // The value of `numRows` is irrelevant. col.setCalendarInterval((CalendarInterval) row.get(fieldIdx, t)); + } else if (pdt instanceof PhysicalVariantType) { + col.setVariant((VariantVal)row.get(fieldIdx, t)); } else { throw new RuntimeException(String.format("DataType %s is not supported" + " in column vectorized reader.", t.sql())); @@ -124,7 +127,7 @@ public static Map toJavaIntMap(ColumnarMap map) { private static void appendValue(WritableColumnVector dst, DataType t, Object o) { if (o == null) { - if (t instanceof CalendarIntervalType) { + if (t instanceof CalendarIntervalType || t instanceof VariantType) { dst.appendStruct(true); } else { dst.appendNull(); @@ -167,6 +170,11 @@ private static void appendValue(WritableColumnVector dst, DataType t, Object o) dst.getChild(0).appendInt(c.months); dst.getChild(1).appendInt(c.days); dst.getChild(2).appendLong(c.microseconds); + } else if (t instanceof VariantType) { + VariantVal v = (VariantVal) o; + dst.appendStruct(false); + dst.getChild(0).appendByteArray(v.getValue(), 0, v.getValue().length); + dst.getChild(1).appendByteArray(v.getMetadata(), 0, v.getMetadata().length); } else if (t instanceof DateType) { dst.appendInt(DateTimeUtils.fromJavaDate((Date) o)); } else if (t instanceof TimestampType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java index 5095e6b0c9c6b..43854c2300fde 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * This class adds the constant support to ColumnVector. 
@@ -67,6 +68,10 @@ public ConstantColumnVector(int numRows, DataType type) { this.childData[0] = new ConstantColumnVector(1, DataTypes.IntegerType); this.childData[1] = new ConstantColumnVector(1, DataTypes.IntegerType); this.childData[2] = new ConstantColumnVector(1, DataTypes.LongType); + } else if (type instanceof VariantType) { + this.childData = new ConstantColumnVector[2]; + this.childData[0] = new ConstantColumnVector(1, DataTypes.BinaryType); + this.childData[1] = new ConstantColumnVector(1, DataTypes.BinaryType); } else { this.childData = null; } @@ -307,4 +312,12 @@ public void setCalendarInterval(CalendarInterval value) { this.childData[1].setInt(value.days); this.childData[2].setLong(value.microseconds); } + + /** + * Sets the Variant `value` for all rows + */ + public void setVariant(VariantVal value) { + this.childData[0].setBinary(value.getValue()); + this.childData[1].setBinary(value.getMetadata()); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index eda58815f3b3a..0a110a204e04b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -28,6 +28,7 @@ import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; /** * A mutable version of {@link ColumnarRow}, which is used in the vectorized hash map for hash @@ -142,6 +143,11 @@ public CalendarInterval getInterval(int ordinal) { return columns[ordinal].getInterval(rowId); } + @Override + public VariantVal getVariant(int ordinal) { + return columns[ordinal].getVariant(rowId); + } + @Override public ColumnarRow getStruct(int ordinal, int numFields) { return columns[ordinal].getStruct(rowId); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 4c8ceff356595..10907c69c2260 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -694,7 +694,7 @@ public final int appendStruct(boolean isNull) { putNull(elementsAppended); elementsAppended++; for (WritableColumnVector c: childColumns) { - if (c.type instanceof StructType) { + if (c.type instanceof StructType || c.type instanceof VariantType) { c.appendStruct(true); } else { c.appendNull(); @@ -975,6 +975,10 @@ protected WritableColumnVector(int capacity, DataType dataType) { this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[1] = reserveNewColumn(capacity, DataTypes.IntegerType); this.childColumns[2] = reserveNewColumn(capacity, DataTypes.LongType); + } else if (type instanceof VariantType) { + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, DataTypes.BinaryType); + this.childColumns[1] = reserveNewColumn(capacity, DataTypes.BinaryType); } else { this.childColumns = null; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala index 3fec13a7f9ba9..7c117e0cace97 100644 
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala @@ -267,6 +267,7 @@ private object RowToColumnConverter { case DoubleType => DoubleConverter case StringType => StringConverter case CalendarIntervalType => CalendarConverter + case VariantType => VariantConverter case at: ArrayType => ArrayConverter(getConverterForType(at.elementType, at.containsNull)) case st: StructType => new StructConverter(st.fields.map( (f) => getConverterForType(f.dataType, f.nullable))) @@ -346,6 +347,15 @@ private object RowToColumnConverter { } } + private object VariantConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val v = row.getVariant(column) + cv.appendStruct(false) + cv.getChild(0).appendByteArray(v.getValue, 0, v.getValue.length) + cv.getChild(1).appendByteArray(v.getMetadata, 0, v.getMetadata.length) + } + } + private case class ArrayConverter(childConverter: TypeConverter) extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { val values = row.getArray(column) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 9811a1d3f33e4..f6b5ba15afbd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedComm import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ /** @@ -131,6 +131,7 @@ object HiveResult { HIVE_STYLE, startField, endField) + case (v: VariantVal, VariantType) => v.toString case (other, _: UserDefinedType[_]) => other.toString } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index cd295f3b17bd6..835308f3d0248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, Tex import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructField, StructType, VariantType} import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.{HadoopFSUtils, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ @@ -503,6 +503,7 @@ case class DataSource( providingInstance() match { case dataSource: CreatableRelationProvider => disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = true) + disallowWritingVariant(outputColumns.map(_.dataType)) dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => @@ -524,6 
+525,7 @@ case class DataSource( providingInstance() match { case dataSource: CreatableRelationProvider => disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = true) + disallowWritingVariant(data.schema.map(_.dataType)) SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false) @@ -560,6 +562,14 @@ case class DataSource( throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError() }) } + + private def disallowWritingVariant(dataTypes: Seq[DataType]): Unit = { + dataTypes.foreach { dt => + if (dt.existsRecursively(_.isInstanceOf[VariantType])) { + throw QueryCompilationErrors.cannotSaveVariantIntoExternalStorageError() + } + } + } } object DataSource extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 069ad9562a7d5..32370562003f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -145,6 +145,8 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[CSVFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: BinaryType => false case _: AtomicType => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index 9c6c77a8b9622..7fb6e98fb0468 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -134,6 +134,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[JsonFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b7e6f11f67d69..623f97499cd55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -237,6 +237,8 @@ class OrcFileFormat } override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index eedd165278aed..f60f7c11eefa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -185,6 +185,11 @@ class ParquetToSparkSchemaConverter( } field match { case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType) + case groupColumn: GroupColumnIO if targetType.contains(VariantType) => + ParquetColumn(VariantType, groupColumn, Seq( + convertField(groupColumn.getChild(0), Some(BinaryType)), + convertField(groupColumn.getChild(1), Some(BinaryType)) + )) case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType) } } @@ -719,6 +724,12 @@ class SparkToParquetSchemaConverter( // Other types // =========== + case VariantType => + Types.buildGroup(repetition) + .addField(convertField(StructField("value", BinaryType, nullable = false))) + .addField(convertField(StructField("metadata", BinaryType, nullable = false))) + .named(field.name) + case StructType(fields) => fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => builder.addField(convertField(field)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index 9535bbd585bce..e410789504e70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -238,6 +238,18 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { case DecimalType.Fixed(precision, scale) => makeDecimalWriter(precision, scale) + case VariantType => + (row: SpecializedGetters, ordinal: Int) => + val v = row.getVariant(ordinal) + consumeGroup { + consumeField("value", 0) { + recordConsumer.addBinary(Binary.fromReusedByteArray(v.getValue)) + } + consumeField("metadata", 1) { + recordConsumer.addBinary(Binary.fromReusedByteArray(v.getMetadata)) + } + } + case t: StructType => val fieldWriters = t.map(_.dataType).map(makeWriter).toArray[ValueWriter] (row: SpecializedGetters, ordinal: Int) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala index 776192992789a..300c0f5004252 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala @@ -140,6 +140,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister { override def equals(other: Any): Boolean = other.isInstanceOf[XmlFileFormat] override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AtomicType => true case st: StructType => st.forall { f => supportDataType(f.dataType) } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 017cc474ea028..8e6bad11c09a9 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -429,6 +429,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_pop(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp
| var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> | +| org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> | | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_double | SELECT xpath_double('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_double(<a><b>1</b><b>2</b></a>, sum(a/b)):double> | | org.apache.spark.sql.catalyst.expressions.xml.XPathDouble | xpath_number | SELECT xpath_number('<a><b>1</b><b>2</b></a>', 'sum(a/b)') | struct<xpath_number(<a><b>1</b><b>2</b></a>, sum(a/b)):double> | diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index f88dcbd465852..10fcee1469398 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -320,6 +320,7 @@ VALUES false VAR false VARCHAR false VARIABLE false +VARIANT false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index b618299ea61a8..be2303a716da5 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -320,6 +320,7 @@ VALUES false VAR false VARCHAR false VARIABLE false +VARIANT false VERSION false VIEW false VIEWS false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala new file mode 100644 index 0000000000000..dde986c555b10 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import scala.util.Random + +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.VariantVal + +class VariantSuite extends QueryTest with SharedSparkSession { + test("basic tests") { + def verifyResult(df: DataFrame): Unit = { + val result = df.collect() + .map(_.get(0).asInstanceOf[VariantVal].toString) + .sorted + .toSeq + val expected = (1 until 10).map(id => "1" * id) + assert(result == expected) + } + + // At this point, JSON parsing logic is not really implemented. We just construct some number + // inputs that are also valid JSON. This exercises passing VariantVal throughout the system.
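+ // For example, parse_json('111') should come back as a VariantVal whose toString() is "111",
+ // matching the expected strings of repeated '1's built in verifyResult above.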
+ val query = spark.sql("select parse_json(repeat('1', id)) as v from range(1, 10)") + verifyResult(query) + + // Write into and read from Parquet. + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + query.write.parquet(tempDir) + verifyResult(spark.read.parquet(tempDir)) + } + } + + test("round trip tests") { + val rand = new Random(42) + val input = Seq.fill(50) { + if (rand.nextInt(10) == 0) { + null + } else { + val value = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(value) + val metadata = new Array[Byte](rand.nextInt(50)) + rand.nextBytes(metadata) + new VariantVal(value, metadata) + } + } + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(input.map(Row(_))), + StructType.fromDDL("v variant") + ) + val result = df.collect().map(_.get(0).asInstanceOf[VariantVal]) + + def prepareAnswer(values: Seq[VariantVal]): Seq[String] = { + values.map(v => if (v == null) "null" else v.debugString()).sorted + } + assert(prepareAnswer(input) == prepareAnswer(result)) + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 72e6fae92cbc2..9bb35bb8719ea 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BETWEEN,BIGINT,BINARY,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPUTE,CONCATENATE,CONSTRAINT,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DELETE,DELIMITED,DESC,DESCRIBE,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EXCEPT,EXCHANGE,EXCLUDE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IF,IGNORE,ILIKE,IMPORT,IN,INCLUDE,INDEX,INDEXES,INNER,INPATH,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,IS,ITEMS,JOIN,KEYS,LAST,LATERAL,LAZY,LEADING,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PERCENTILE_CONT,PERCENTILE_DISC,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 3cf6fcbc65ace..5ccd40aefa255 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -191,6 +191,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable } override def supportDataType(dataType: DataType): Boolean = dataType match { + case _: VariantType => false + case _: AnsiIntervalType => false case _: AtomicType => true