diff --git a/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java b/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java
index a8a7de065ece..8f913a97becb 100644
--- a/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java
+++ b/api/src/main/java/org/apache/iceberg/types/GetProjectedIds.java
@@ -47,7 +47,7 @@ public Set<Integer> struct(Types.StructType struct, List<Set<Integer>> fieldResults) {
   @Override
   public Set<Integer> field(Types.NestedField field, Set<Integer> fieldResult) {
-    if ((includeStructIds && field.type().isStructType()) || field.type().isPrimitiveType()) {
+    if ((includeStructIds && field.type().isStructType()) || field.type().isPrimitiveType() || field.type() instanceof Types.VariantType) {
       fieldIds.add(field.fieldId());
     }
     return fieldIds;
@@ -72,4 +72,9 @@ public Set<Integer> map(Types.MapType map, Set<Integer> keyResult, Set<Integer> valueResult) {
     }
     return fieldIds;
   }
+
+  @Override
+  public Set<Integer> variant() {
+    return null;
+  }
 }
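With this change, variant fields are projected like primitives. A minimal sketch of the effect, using a hypothetical two-field schema (TypeUtil.getProjectedIds is the existing entry point used by ParquetSchemaUtil.pruneColumns later in this diff):

    Schema schema =
        new Schema(
            Types.NestedField.required(1, "id", Types.IntegerType.get()),
            Types.NestedField.optional(2, "var", Types.VariantType.get()));

    // before this change the variant field id was skipped; now ids contains 1 and 2
    Set<Integer> ids = TypeUtil.getProjectedIds(schema);
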
diff --git a/core/src/main/java/org/apache/iceberg/variants/Variants.java b/core/src/main/java/org/apache/iceberg/variants/Variants.java
index e10682fe544a..5d9af230928b 100644
--- a/core/src/main/java/org/apache/iceberg/variants/Variants.java
+++ b/core/src/main/java/org/apache/iceberg/variants/Variants.java
@@ -172,6 +172,10 @@ enum BasicType {
     ARRAY
   }
 
+  public static VariantMetadata emptyMetadata() {
+    return SerializedMetadata.EMPTY_V1_METADATA;
+  }
+
   public static VariantMetadata metadata(ByteBuffer metadata) {
     return SerializedMetadata.from(metadata);
   }
@@ -209,59 +213,59 @@ public static VariantPrimitive<?> ofNull() {
     return new PrimitiveWrapper<>(PhysicalType.NULL, null);
   }
 
-  static VariantPrimitive<Boolean> of(boolean value) {
+  public static VariantPrimitive<Boolean> of(boolean value) {
     return new PrimitiveWrapper<>(PhysicalType.BOOLEAN_TRUE, value);
   }
 
-  static VariantPrimitive<Byte> of(byte value) {
+  public static VariantPrimitive<Byte> of(byte value) {
     return new PrimitiveWrapper<>(PhysicalType.INT8, value);
   }
 
-  static VariantPrimitive<Short> of(short value) {
+  public static VariantPrimitive<Short> of(short value) {
     return new PrimitiveWrapper<>(PhysicalType.INT16, value);
   }
 
-  static VariantPrimitive<Integer> of(int value) {
+  public static VariantPrimitive<Integer> of(int value) {
     return new PrimitiveWrapper<>(PhysicalType.INT32, value);
   }
 
-  static VariantPrimitive<Long> of(long value) {
+  public static VariantPrimitive<Long> of(long value) {
     return new PrimitiveWrapper<>(PhysicalType.INT64, value);
   }
 
-  static VariantPrimitive<Float> of(float value) {
+  public static VariantPrimitive<Float> of(float value) {
     return new PrimitiveWrapper<>(PhysicalType.FLOAT, value);
   }
 
-  static VariantPrimitive<Double> of(double value) {
+  public static VariantPrimitive<Double> of(double value) {
     return new PrimitiveWrapper<>(PhysicalType.DOUBLE, value);
   }
 
-  static VariantPrimitive<Integer> ofDate(int value) {
+  public static VariantPrimitive<Integer> ofDate(int value) {
     return new PrimitiveWrapper<>(PhysicalType.DATE, value);
   }
 
-  static VariantPrimitive<Integer> ofIsoDate(String value) {
+  public static VariantPrimitive<Integer> ofIsoDate(String value) {
     return ofDate(DateTimeUtil.isoDateToDays(value));
   }
 
-  static VariantPrimitive<Long> ofTimestamptz(long value) {
+  public static VariantPrimitive<Long> ofTimestamptz(long value) {
     return new PrimitiveWrapper<>(PhysicalType.TIMESTAMPTZ, value);
   }
 
-  static VariantPrimitive<Long> ofIsoTimestamptz(String value) {
+  public static VariantPrimitive<Long> ofIsoTimestamptz(String value) {
     return ofTimestamptz(DateTimeUtil.isoTimestamptzToMicros(value));
   }
 
-  static VariantPrimitive<Long> ofTimestampntz(long value) {
+  public static VariantPrimitive<Long> ofTimestampntz(long value) {
     return new PrimitiveWrapper<>(PhysicalType.TIMESTAMPNTZ, value);
   }
 
-  static VariantPrimitive<Long> ofIsoTimestampntz(String value) {
+  public static VariantPrimitive<Long> ofIsoTimestampntz(String value) {
     return ofTimestampntz(DateTimeUtil.isoTimestampToMicros(value));
   }
 
-  static VariantPrimitive<BigDecimal> of(BigDecimal value) {
+  public static VariantPrimitive<BigDecimal> of(BigDecimal value) {
     int bitLength = value.unscaledValue().bitLength();
     if (bitLength < 32) {
       return new PrimitiveWrapper<>(PhysicalType.DECIMAL4, value);
@@ -274,11 +278,11 @@ static VariantPrimitive<BigDecimal> of(BigDecimal value) {
     throw new UnsupportedOperationException("Unsupported decimal precision: " + value.precision());
   }
 
-  static VariantPrimitive<ByteBuffer> of(ByteBuffer value) {
+  public static VariantPrimitive<ByteBuffer> of(ByteBuffer value) {
     return new PrimitiveWrapper<>(PhysicalType.BINARY, value);
   }
 
-  static VariantPrimitive<String> of(String value) {
+  public static VariantPrimitive<String> of(String value) {
     return new PrimitiveWrapper<>(PhysicalType.STRING, value);
   }
 }
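Making these factories public lets code outside the org.apache.iceberg.variants package build values directly. A short sketch using only the methods shown above:

    // primitives wrap a physical type and a value
    VariantPrimitive<Integer> num = Variants.of(34);
    VariantPrimitive<Long> ts = Variants.ofIsoTimestamptz("2024-11-07T12:33:54.123456+00:00");

    // an empty metadata dictionary for values that reference no field names
    VariantMetadata metadata = Variants.emptyMetadata();
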
diff --git a/core/src/test/java/org/apache/iceberg/variants/VariantTestUtil.java b/core/src/test/java/org/apache/iceberg/variants/VariantTestUtil.java
index 576e06a9d1c5..9f1934836213 100644
--- a/core/src/test/java/org/apache/iceberg/variants/VariantTestUtil.java
+++ b/core/src/test/java/org/apache/iceberg/variants/VariantTestUtil.java
@@ -18,6 +18,8 @@
  */
 package org.apache.iceberg.variants;
 
+import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
+
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
@@ -27,10 +29,55 @@
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 
 public class VariantTestUtil {
   private VariantTestUtil() {}
 
+  public static void assertEqual(VariantMetadata expected, VariantMetadata actual) {
+    assertThat(actual).isNotNull();
+    assertThat(expected).isNotNull();
+    assertThat(actual.dictionarySize())
+        .as("Dictionary size should match")
+        .isEqualTo(expected.dictionarySize());
+
+    for (int i = 0; i < expected.dictionarySize(); i += 1) {
+      assertThat(actual.get(i)).isEqualTo(expected.get(i));
+    }
+  }
+
+  public static void assertEqual(VariantValue expected, VariantValue actual) {
+    assertThat(actual).isNotNull();
+    assertThat(expected).isNotNull();
+    assertThat(actual.type()).as("Variant type should match").isEqualTo(expected.type());
+
+    if (expected.type() == Variants.PhysicalType.OBJECT) {
+      VariantObject expectedObject = expected.asObject();
+      VariantObject actualObject = actual.asObject();
+      assertThat(actualObject.numFields())
+          .as("Variant object num fields should match")
+          .isEqualTo(expectedObject.numFields());
+      for (String fieldName : expectedObject.fieldNames()) {
+        assertEqual(expectedObject.get(fieldName), actualObject.get(fieldName));
+      }
+
+    } else if (expected.type() == Variants.PhysicalType.ARRAY) {
+      VariantArray expectedArray = expected.asArray();
+      VariantArray actualArray = actual.asArray();
+      assertThat(actualArray.numElements())
+          .as("Variant array num elements should match")
+          .isEqualTo(expectedArray.numElements());
+      for (int i = 0; i < expectedArray.numElements(); i += 1) {
+        assertEqual(expectedArray.get(i), actualArray.get(i));
+      }
+
+    } else {
+      assertThat(actual.asPrimitive().get())
+          .as("Variant primitive value should match")
+          .isEqualTo(expected.asPrimitive().get());
+    }
+  }
+
   private static byte primitiveHeader(int primitiveType) {
     return (byte) (primitiveType << 2);
   }
@@ -60,7 +107,11 @@ static SerializedPrimitive createString(String string) {
     return SerializedPrimitive.from(buffer, buffer.get(0));
   }
 
-  static ByteBuffer createMetadata(Collection<String> fieldNames, boolean sortNames) {
+  public static ByteBuffer emptyMetadata() {
+    return createMetadata(ImmutableList.of(), true);
+  }
+
+  public static ByteBuffer createMetadata(Collection<String> fieldNames, boolean sortNames) {
     if (fieldNames.isEmpty()) {
       return SerializedMetadata.EMPTY_V1_BUFFER;
     }
@@ -108,7 +159,7 @@ static ByteBuffer createMetadata(Collection<String> fieldNames, boolean sortNames) {
     return buffer;
   }
 
-  static ByteBuffer createObject(ByteBuffer metadataBuffer, Map<String, VariantValue> data) {
+  public static ByteBuffer createObject(ByteBuffer metadataBuffer, Map<String, VariantValue> data) {
     // create the metadata to look up field names
     VariantMetadata metadata = Variants.metadata(metadataBuffer);
diff --git a/parquet/src/main/java/org/apache/iceberg/data/parquet/BaseParquetReaders.java b/parquet/src/main/java/org/apache/iceberg/data/parquet/BaseParquetReaders.java
index 70e6b3ff447e..7e972e6cedff 100644
--- a/parquet/src/main/java/org/apache/iceberg/data/parquet/BaseParquetReaders.java
+++ b/parquet/src/main/java/org/apache/iceberg/data/parquet/BaseParquetReaders.java
@@ -18,6 +18,7 @@
  */
 package org.apache.iceberg.data.parquet;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -26,7 +27,9 @@
 import org.apache.iceberg.parquet.ParquetSchemaUtil;
 import org.apache.iceberg.parquet.ParquetValueReader;
 import org.apache.iceberg.parquet.ParquetValueReaders;
+import org.apache.iceberg.parquet.ParquetVariantVisitor;
 import org.apache.iceberg.parquet.TypeWithSchemaVisitor;
+import org.apache.iceberg.parquet.VariantReaderBuilder;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
@@ -431,6 +434,16 @@ public ParquetValueReader<?> primitive(
     }
   }
 
+  @Override
+  public ParquetValueReader<?> variant(Types.VariantType iVariant, ParquetValueReader<?> reader) {
+    return reader;
+  }
+
+  @Override
+  public ParquetVariantVisitor<ParquetValueReader<?>> variantVisitor() {
+    return new VariantReaderBuilder(type, Arrays.asList(currentPath()));
+  }
+
   MessageType type() {
     return type;
   }
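With variantVisitor() wired in, the reader builders surface variant columns as Variant values. A minimal sketch, assuming the existing GenericParquetReaders entry point and a fileSchema already read from the Parquet footer:

    // dispatches variant groups to VariantReaderBuilder via TypeWithSchemaVisitor
    ParquetValueReader<Record> reader =
        GenericParquetReaders.buildReader(expectedSchema, fileSchema);
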
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetSchemaUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetSchemaUtil.java
index a0dc54c1cdd9..68a9aa979fdf 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetSchemaUtil.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetSchemaUtil.java
@@ -27,6 +27,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.types.TypeUtil;
 import org.apache.iceberg.types.Types;
+import org.apache.parquet.io.InvalidRecordException;
 import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.PrimitiveType;
@@ -75,6 +76,32 @@ private static Schema convertInternal(
         converter.getAliases());
   }
 
+  /**
+   * Returns true if the name identifies a field in the struct/group.
+   *
+   * @param group a GroupType
+   * @param name a String name
+   * @return true if the group contains a field with the given name
+   */
+  public static boolean hasField(GroupType group, String name) {
+    return fieldType(group, name) != null;
+  }
+
+  /**
+   * Returns the Type of the named field in the struct/group, or null.
+   *
+   * @param group a GroupType
+   * @param name a String name
+   * @return the Type of the field in the group, or null if it is not present
+   */
+  public static Type fieldType(GroupType group, String name) {
+    try {
+      return group.getType(name);
+    } catch (InvalidRecordException ignored) {
+      return null;
+    }
+  }
+
   public static MessageType pruneColumns(MessageType fileSchema, Schema expectedSchema) {
     // column order must match the incoming type, so it doesn't matter that the ids are unordered
     Set<Integer> selectedIds = TypeUtil.getProjectedIds(expectedSchema);
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java
index 73ce83b9bfdd..bb8930085924 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetValueReaders.java
@@ -63,6 +63,14 @@ public static <T> ParquetValueReader<T> unboxed(ColumnDescriptor desc) {
     return new UnboxedReader<>(desc);
   }
 
+  public static ParquetValueReader<Byte> intsAsByte(ColumnDescriptor desc) {
+    return new IntAsByteReader(desc);
+  }
+
+  public static ParquetValueReader<Short> intsAsShort(ColumnDescriptor desc) {
+    return new IntAsShortReader(desc);
+  }
+
   public static ParquetValueReader<String> strings(ColumnDescriptor desc) {
     return new StringReader(desc);
   }
@@ -390,6 +398,28 @@ public String read(String reuse) {
     }
   }
 
+  private static class IntAsByteReader extends UnboxedReader<Byte> {
+    private IntAsByteReader(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public Byte read(Byte ignored) {
+      return (byte) readInteger();
+    }
+  }
+
+  private static class IntAsShortReader extends UnboxedReader<Short> {
+    private IntAsShortReader(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public Short read(Short ignored) {
+      return (short) readInteger();
+    }
+  }
+
   public static class IntAsLongReader extends UnboxedReader<Long> {
     public IntAsLongReader(ColumnDescriptor desc) {
       super(desc);
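Parquet has no 8- or 16-bit physical type, so shredded int8 and int16 values arrive as INT32 and are narrowed on read. A usage sketch, where desc is a hypothetical descriptor for an INT32 column annotated int(8) or int(16):

    ParquetValueReader<Byte> int8Reader = ParquetValueReaders.intsAsByte(desc);
    ParquetValueReader<Short> int16Reader = ParquetValueReaders.intsAsShort(desc);
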
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantReaders.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantReaders.java
new file mode 100644
index 000000000000..10016d19d8c4
--- /dev/null
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantReaders.java
@@ -0,0 +1,406 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.parquet;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.relocated.com.google.common.collect.Streams;
+import org.apache.iceberg.variants.ShreddedObject;
+import org.apache.iceberg.variants.Variant;
+import org.apache.iceberg.variants.VariantMetadata;
+import org.apache.iceberg.variants.VariantObject;
+import org.apache.iceberg.variants.VariantValue;
+import org.apache.iceberg.variants.Variants;
+import org.apache.iceberg.variants.Variants.PhysicalType;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.column.page.PageReadStore;
+
+public class ParquetVariantReaders {
+  private ParquetVariantReaders() {}
+
+  public interface VariantValueReader extends ParquetValueReader<VariantValue> {
+    @Override
+    default VariantValue read(VariantValue reuse) {
+      throw new UnsupportedOperationException("Variants must be read using read(VariantMetadata)");
+    }
+
+    /** Reads a variant value. */
+    VariantValue read(VariantMetadata metadata);
+  }
+
+  private static final VariantValue MISSING = null;
+
+  @SuppressWarnings("unchecked")
+  public static ParquetValueReader<Variant> variant(
+      ParquetValueReader<?> metadata, ParquetValueReader<?> value) {
+    return new VariantReader(
+        (ParquetValueReader<VariantMetadata>) metadata, (VariantValueReader) value);
+  }
+
+  public static ParquetValueReader<VariantMetadata> metadata(ColumnDescriptor desc) {
+    return new VariantMetadataReader(desc);
+  }
+
+  public static VariantValueReader serialized(ColumnDescriptor desc) {
+    return new SerializedVariantReader(desc);
+  }
+
+  public static VariantValueReader shredded(
+      int valueDL,
+      ParquetValueReader<?> valueReader,
+      int typedDL,
+      ParquetValueReader<?> typedReader) {
+    return new ShreddedVariantReader(
+        valueDL, (VariantValueReader) valueReader, typedDL, (VariantValueReader) typedReader);
+  }
+
+  public static VariantValueReader objects(
+      int valueDL,
+      ParquetValueReader<?> valueReader,
+      int fieldDL,
+      List<String> fieldNames,
+      List<VariantValueReader> fieldReaders) {
+    return new ShreddedObjectReader(
+        valueDL, (VariantValueReader) valueReader, fieldDL, fieldNames, fieldReaders);
+  }
+
+  public static <T> VariantValueReader asVariant(PhysicalType type, ParquetValueReader<T> reader) {
+    return new ValueAsVariantReader<>(type, reader);
+  }
+
+  private abstract static class DelegatingValueReader<S, T> implements ParquetValueReader<T> {
+    private final ParquetValueReader<S> reader;
+
+    private DelegatingValueReader(ParquetValueReader<S> reader) {
+      this.reader = reader;
+    }
+
+    protected S readFromDelegate(S reuse) {
+      return reader.read(reuse);
+    }
+
+    @Override
+    public TripleIterator<?> column() {
+      return reader.column();
+    }
+
+    @Override
+    public List<TripleIterator<?>> columns() {
+      return reader.columns();
+    }
+
+    @Override
+    public void setPageSource(PageReadStore pageStore) {
+      reader.setPageSource(pageStore);
+    }
+  }
+
+  private static ByteBuffer readBinary(ColumnIterator<?> column) {
+    ByteBuffer data = column.nextBinary().toByteBuffer();
+    byte[] array = new byte[data.remaining()];
+    data.get(array, 0, data.remaining());
+    return ByteBuffer.wrap(array).order(ByteOrder.LITTLE_ENDIAN);
+  }
+
+  private static class VariantMetadataReader
+      extends ParquetValueReaders.PrimitiveReader<VariantMetadata> {
+    public VariantMetadataReader(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public VariantMetadata read(VariantMetadata reuse) {
+      return Variants.metadata(readBinary(column));
+    }
+  }
+
+  private static class SerializedVariantReader
+      extends ParquetValueReaders.PrimitiveReader<VariantValue> implements VariantValueReader {
+    public SerializedVariantReader(ColumnDescriptor desc) {
+      super(desc);
+    }
+
+    @Override
+    public VariantValue read(VariantMetadata metadata) {
+      return Variants.value(metadata, readBinary(column));
+    }
+  }
+
+  private static class ValueAsVariantReader<T> extends DelegatingValueReader<T, VariantValue>
+      implements VariantValueReader {
+    private final PhysicalType type;
+
+    private ValueAsVariantReader(PhysicalType type, ParquetValueReader<T> reader) {
+      super(reader);
+      this.type = type;
+    }
+
+    @Override
+    public VariantValue read(VariantMetadata ignored) {
+      return Variants.of(type, readFromDelegate(null));
+    }
+  }
+
+  /**
+   * A Variant reader that combines value and typed_value columns from Parquet.
+   *
+   * <p>This reader does not handle merging partially shredded objects. To handle shredded objects,
+   * use {@link ShreddedObjectReader}.
+   */
+  private static class ShreddedVariantReader implements VariantValueReader {
+    private final int valueDL;
+    private final VariantValueReader valueReader;
+    private final int typedDL;
+    private final VariantValueReader typedReader;
+    private final TripleIterator<?> column;
+    private final List<TripleIterator<?>> children;
+
+    public ShreddedVariantReader(
+        int valueDL, VariantValueReader valueReader, int typedDL, VariantValueReader typedReader) {
+      this.valueDL = valueDL;
+      this.valueReader = valueReader;
+      this.typedDL = typedDL;
+      this.typedReader = typedReader;
+      this.column = valueReader != null ? valueReader.column() : typedReader.column();
+      this.children = children(valueReader, typedReader);
+    }
+
+    @Override
+    public VariantValue read(VariantMetadata metadata) {
+      VariantValue value = ParquetVariantReaders.read(metadata, valueReader, valueDL);
+      VariantValue typed = ParquetVariantReaders.read(metadata, typedReader, typedDL);
+
+      if (typed != null) {
+        return typed;
+      }
+
+      return value;
+    }
+
+    @Override
+    public TripleIterator<?> column() {
+      return column;
+    }
+
+    @Override
+    public List<TripleIterator<?>> columns() {
+      return children;
+    }
+
+    @Override
+    public void setPageSource(PageReadStore pageStore) {
+      if (valueReader != null) {
+        valueReader.setPageSource(pageStore);
+      }
+
+      if (typedReader != null) {
+        typedReader.setPageSource(pageStore);
+      }
+    }
+  }
+
+  /**
+   * A Variant reader that combines value and partially shredded object columns.
+   *
+   * <p>This reader handles partially shredded objects. For shredded values, use {@link
+   * ShreddedVariantReader} instead.
+   */
+  private static class ShreddedObjectReader implements VariantValueReader {
+    private final int valueDL;
+    private final VariantValueReader valueReader;
+    private final int fieldsDL;
+    private final String[] fieldNames;
+    private final VariantValueReader[] fieldReaders;
+    private final TripleIterator<?> valueColumn;
+    private final TripleIterator<?> fieldColumn;
+    private final List<TripleIterator<?>> children;
+
+    public ShreddedObjectReader(
+        int valueDL,
+        VariantValueReader valueReader,
+        int fieldsDL,
+        List<String> fieldNames,
+        List<VariantValueReader> fieldReaders) {
+      this.valueDL = valueDL;
+      this.valueReader = valueReader;
+      this.fieldsDL = fieldsDL;
+      this.fieldNames = fieldNames.toArray(String[]::new);
+      this.fieldReaders = fieldReaders.toArray(VariantValueReader[]::new);
+      this.valueColumn = valueReader.column();
+      this.fieldColumn = this.fieldReaders[0].column();
+      this.children = children(Iterables.concat(ImmutableList.of(valueReader), fieldReaders));
+    }
+
+    @Override
+    public VariantValue read(VariantMetadata metadata) {
+      boolean isObject = fieldColumn.currentDefinitionLevel() > fieldsDL;
+      VariantValue value = ParquetVariantReaders.read(metadata, valueReader, valueDL);
+
+      if (isObject) {
+        ShreddedObject object;
+        if (value == MISSING) {
+          object = Variants.object(metadata);
+        } else {
+          Preconditions.checkArgument(
+              value.type() == PhysicalType.OBJECT,
+              "Invalid variant, non-object value with shredded fields: %s",
+              value);
+          object = Variants.object(metadata, (VariantObject) value);
+        }
+
+        for (int i = 0; i < fieldReaders.length; i += 1) {
+          // each field is a ShreddedVariantReader or ShreddedObjectReader that handles DL
+          String name = fieldNames[i];
+          VariantValue fieldValue = fieldReaders[i].read(metadata);
+          if (fieldValue == MISSING) {
+            object.remove(name);
+          } else {
+            object.put(name, fieldValue);
+          }
+        }
+
+        return object;
+      }
+
+      // for non-objects, advance the field iterators
+      for (VariantValueReader reader : fieldReaders) {
+        for (TripleIterator<?> child : reader.columns()) {
+          child.nextNull();
+        }
+      }
+
+      return value;
+    }
+
+    @Override
+    public TripleIterator<?> column() {
+      return valueColumn;
+    }
+
+    @Override
+    public List<TripleIterator<?>> columns() {
+      return children;
+    }
+
+    @Override
+    public void setPageSource(PageReadStore pageStore) {
+      valueReader.setPageSource(pageStore);
+      for (VariantValueReader reader : fieldReaders) {
+        reader.setPageSource(pageStore);
+      }
+    }
+  }
+
+  private static class VariantReader implements ParquetValueReader<Variant> {
+    private final ParquetValueReader<VariantMetadata> metadataReader;
+    private final VariantValueReader valueReader;
+    private final TripleIterator<?> column;
+    private final List<TripleIterator<?>> children;
+
+    public VariantReader(
+        ParquetValueReader<VariantMetadata> metadataReader, VariantValueReader valueReader) {
+      this.metadataReader = metadataReader;
+      this.valueReader = valueReader;
+      // metadata is always non-null so its column can be used for the variant
+      this.column = metadataReader.column();
+      this.children = children(metadataReader, valueReader);
+    }
+
+    @Override
+    public Variant read(Variant ignored) {
+      VariantMetadata metadata = metadataReader.read(null);
+      VariantValue value = valueReader.read(metadata);
+      if (value == MISSING) {
+        return new Variant() {
+          @Override
+          public VariantMetadata metadata() {
+            return metadata;
+          }
+
+          @Override
+          public VariantValue value() {
+            return Variants.ofNull();
+          }
+        };
+      }
+
+      return new Variant() {
+        @Override
+        public VariantMetadata metadata() {
+          return metadata;
+        }
+
+        @Override
+        public VariantValue value() {
+          return value;
+        }
+      };
+    }
+
+    @Override
+    public TripleIterator<?> column() {
+      return column;
+    }
+
+    @Override
+    public List<TripleIterator<?>> columns() {
+      return children;
+    }
+
+    @Override
+    public void setPageSource(PageReadStore pageStore) {
+      metadataReader.setPageSource(pageStore);
+      valueReader.setPageSource(pageStore);
+    }
+  }
+
+  private static VariantValue read(
+      VariantMetadata metadata, VariantValueReader reader, int definitionLevel) {
+    if (reader != null) {
+      if (reader.column().currentDefinitionLevel() > definitionLevel) {
+        return reader.read(metadata);
+      }
+
+      for (TripleIterator<?> child : reader.columns()) {
+        child.nextNull();
+      }
+    }
+
+    return MISSING;
+  }
+
+  private static List<TripleIterator<?>> children(ParquetValueReader<?>... readers) {
+    return children(Arrays.asList(readers));
+  }
+
+  private static List<TripleIterator<?>> children(
+      Iterable<? extends ParquetValueReader<?>> readers) {
+    return ImmutableList.copyOf(
+        Iterables.concat(
+            Iterables.transform(
+                Streams.stream(readers).filter(Objects::nonNull).collect(Collectors.toList()),
+                ParquetValueReader::columns)));
+  }
+}
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantVisitor.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantVisitor.java
new file mode 100644
index 000000000000..cabc0dc81e9c
--- /dev/null
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetVariantVisitor.java
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.parquet;
+
+import java.util.List;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation.ListLogicalTypeAnnotation;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
+import org.apache.parquet.schema.Type;
+
+public abstract class ParquetVariantVisitor<R> {
+  static final String METADATA = "metadata";
+  static final String VALUE = "value";
+  static final String TYPED_VALUE = "typed_value";
+
+  /**
+   * Handles the root variant column group.
+   *
+   * <p>The value and typed_value results are combined by calling {@link #value}.
+   *
+   * <pre>
+   *   group v (VARIANT) { <-- metadata result and combined value and typed_value result
+   *     required binary metadata;
+   *     optional binary value;
+   *     optional ... typed_value;
+   *   }
+   * </pre>
+   */
+  public R variant(GroupType variant, R metadataResult, R valueResult) {
+    return null;
+  }
+
+  /**
+   * Handles a serialized variant metadata column.
+   *
+   * <pre>
+   *   group v (VARIANT) {
+   *     required binary metadata; <-- this column
+   *     optional binary value;
+   *     optional ... typed_value;
+   *   }
+   * </pre>
+   */
+  public R metadata(PrimitiveType metadata) {
+    return null;
+  }
+
+  /**
+   * Handles a serialized variant value column.
+   *
+   * <pre>
+   *   group variant_value_pair {
+   *     optional binary value; <-- this column
+   *     optional ... typed_value;
+   *   }
+   * </pre>
+   */
+  public R serialized(PrimitiveType value) {
+    return null;
+  }
+
+  /**
+   * Handles a shredded primitive typed_value column.
+   *
+   * <pre>
+   *   group variant_value_pair {
+   *     optional binary value;
+   *     optional int32 typed_value; <-- this column when it is any primitive
+   *   }
+   * </pre>
+   */
+  public R primitive(PrimitiveType primitive) {
+    return null;
+  }
+
+  /**
+   * Handles a variant value result and typed_value result pair.
+   *
+   * <p>The value and typed_value pair may be nested in an object field, array element, or in the
+   * root group of a variant.
+   *
+   * <p>This method is also called when the typed_value field is missing.
+   *
+   * <pre>
+   *   group variant_value_pair { <-- value result and typed_value result
+   *     optional binary value;
+   *     optional ... typed_value;
+   *   }
+   * </pre>
+   */
+  public R value(GroupType value, R valueResult, R typedResult) {
+    return null;
+  }
+
+  /**
+   * Handles a shredded object value result and a list of field value results.
+   *
+   * <p>Each field's value and typed_value results are combined by calling {@link #value}.
+   *
+   * <pre>
+   *   group variant_value_pair {  <-- value result and typed_value field results
+   *     optional binary value;
+   *     optional group typed_value {
+   *       required group a {
+   *         optional binary value;
+   *         optional binary typed_value (UTF8);
+   *       }
+   *       ...
+   *     }
+   *   }
+   * </pre>
+   */
+  public R object(GroupType object, R valueResult, List<R> fieldResults) {
+    return null;
+  }
+
+  /**
+   * Handles a shredded array value result and an element value result.
+   *
+   * <p>The element's value and typed_value results are combined by calling {@link #value}.
+   *
+   * <pre>
+   *   group variant_value_pair {  <-- value result and element result
+   *     optional binary value;
+   *     optional group typed_value (LIST) {
+   *       repeated group list {
+   *         required group element {
+   *           optional binary value;
+   *           optional binary typed_value (UTF8);
+   *         }
+   *       }
+   *     }
+   *   }
+   * </pre>
+   */
+  public R array(GroupType array, R valueResult, R elementResult) {
+    return null;
+  }
+
+  /** Handler called before visiting any primitive or group type. */
+  public void beforeField(Type type) {}
+
+  /** Handler called after visiting any primitive or group type. */
+  public void afterField(Type type) {}
+
+  public static <R> R visit(GroupType type, ParquetVariantVisitor<R> visitor) {
+    Preconditions.checkArgument(
+        ParquetSchemaUtil.hasField(type, METADATA), "Invalid variant, missing metadata: %s", type);
+
+    Type metadataType = type.getType(METADATA);
+    Preconditions.checkArgument(
+        isBinary(metadataType), "Invalid variant metadata, expecting BINARY: %s", metadataType);
+
+    R metadataResult =
+        withBeforeAndAfter(
+            () -> visitor.metadata(metadataType.asPrimitiveType()), metadataType, visitor);
+    R valueResult = visitValue(type, visitor);
+
+    return visitor.variant(type, metadataResult, valueResult);
+  }
+
+  public static <R> R visitValue(GroupType valueGroup, ParquetVariantVisitor<R> visitor) {
+    R valueResult;
+    if (ParquetSchemaUtil.hasField(valueGroup, VALUE)) {
+      Type valueType = valueGroup.getType(VALUE);
+      Preconditions.checkArgument(
+          isBinary(valueType), "Invalid variant value, expecting BINARY: %s", valueType);
+
+      valueResult =
+          withBeforeAndAfter(
+              () -> visitor.serialized(valueType.asPrimitiveType()), valueType, visitor);
+    } else {
+      Preconditions.checkArgument(
+          ParquetSchemaUtil.hasField(valueGroup, TYPED_VALUE),
+          "Invalid variant, missing both value and typed_value: %s",
+          valueGroup);
+
+      valueResult = null;
+    }
+
+    if (ParquetSchemaUtil.hasField(valueGroup, TYPED_VALUE)) {
+      Type typedValueType = valueGroup.getType(TYPED_VALUE);
+
+      if (typedValueType.isPrimitive()) {
+        R typedResult =
+            withBeforeAndAfter(
+                () -> visitor.primitive(typedValueType.asPrimitiveType()), typedValueType, visitor);
+
+        return visitor.value(valueGroup, valueResult, typedResult);
+
+      } else if (typedValueType.getLogicalTypeAnnotation() instanceof ListLogicalTypeAnnotation) {
+        R elementResult =
+            withBeforeAndAfter(
+                () -> visitArray(typedValueType.asGroupType(), visitor), typedValueType, visitor);
+
+        return visitor.array(valueGroup, valueResult, elementResult);
+
+      } else {
+        List<R> results =
+            withBeforeAndAfter(
+                () -> visitObjectFields(typedValueType.asGroupType(), visitor),
+                typedValueType,
+                visitor);
+
+        return visitor.object(valueGroup, valueResult, results);
+      }
+    }
+
+    // there was no typed_value field, but the value result must be handled
+    return visitor.value(valueGroup, valueResult, null);
+  }
+
+  private static <R> R visitArray(GroupType array, ParquetVariantVisitor<R> visitor) {
+    Preconditions.checkArgument(
+        array.getFieldCount() == 1,
+        "Invalid variant array: does not contain single repeated field: %s",
+        array);
+
+    Type repeated = array.getFields().get(0);
+    Preconditions.checkArgument(
+        repeated.isRepetition(Type.Repetition.REPEATED),
+        "Invalid variant array: inner group is not repeated");
+
+    // 3-level structure is required; element is always the only child of the repeated field
+    return withBeforeAndAfter(
+        () -> visitElement(repeated.asGroupType().getType(0), visitor), repeated, visitor);
+  }
+
+  private static <R> R visitElement(Type element, ParquetVariantVisitor<R> visitor) {
+    return withBeforeAndAfter(() -> visitValue(element.asGroupType(), visitor), element, visitor);
+  }
+
+  private static <R> List<R> visitObjectFields(GroupType fields, ParquetVariantVisitor<R> visitor) {
+    List<R> results = Lists.newArrayList();
+    for (Type fieldType : fields.getFields()) {
+      Preconditions.checkArgument(
+          !fieldType.isPrimitive(), "Invalid shredded object field, not a group: %s", fieldType);
+      R fieldResult =
+          withBeforeAndAfter(
+              () -> visitValue(fieldType.asGroupType(), visitor), fieldType, visitor);
+      results.add(fieldResult);
+    }
+
+    return results;
+  }
+
+  @FunctionalInterface
+  interface Action<R> {
+    R invoke();
+  }
+
+  private static <R> R withBeforeAndAfter(
+      Action<R> task, Type type, ParquetVariantVisitor<?> visitor) {
+    visitor.beforeField(type);
+    try {
+      return task.invoke();
+    } finally {
+      visitor.afterField(type);
+    }
+  }
+
+  private static boolean isBinary(Type type) {
+    return type.isPrimitive()
+        && type.asPrimitiveType().getPrimitiveTypeName() == PrimitiveTypeName.BINARY;
+  }
+}
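Subclasses override only the callbacks they need; results flow bottom-up through value, object, array, and variant. As a quick illustration, a hypothetical visitor that describes the shredding structure as strings:

    ParquetVariantVisitor<String> describer =
        new ParquetVariantVisitor<String>() {
          @Override
          public String primitive(PrimitiveType primitive) {
            return primitive.toString();
          }

          @Override
          public String value(GroupType group, String valueResult, String typedResult) {
            return "value(" + valueResult + ", " + typedResult + ")";
          }
        };

    // walks metadata, value, and typed_value and combines the results
    String description = ParquetVariantVisitor.visit(variantGroup, describer);
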
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java b/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java
index d48485305e8a..83eba84b47d9 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/PruneColumns.java
@@ -26,6 +26,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.types.Types.ListType;
 import org.apache.iceberg.types.Types.MapType;
+import org.apache.iceberg.types.Types.NestedField;
 import org.apache.iceberg.types.Types.StructType;
 import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.LogicalTypeAnnotation;
@@ -57,7 +58,7 @@ public Type message(StructType expected, MessageType message, List<Type> fields) {
           hasChange = true;
           builder.addField(field);
         } else {
-          if (isStruct(originalField)) {
+          if (isStruct(originalField, expected.field(fieldId))) {
             hasChange = true;
             builder.addField(originalField.asGroupType().withNewFields(Collections.emptyList()));
           } else {
@@ -152,6 +153,11 @@ public Type map(MapType expected, GroupType map, Type key, Type value) {
     return null;
   }
 
+  @Override
+  public Type variant(org.apache.iceberg.types.Types.VariantType expected, Type variant) {
+    return variant;
+  }
+
   @Override
   public Type primitive(
       org.apache.iceberg.types.Type.PrimitiveType expected, PrimitiveType primitive) {
@@ -162,9 +168,11 @@ private Integer getId(Type type) {
     return type.getId() == null ? null : type.getId().intValue();
   }
 
-  private boolean isStruct(Type field) {
+  private boolean isStruct(Type field, NestedField expected) {
     if (field.isPrimitive()) {
       return false;
+    } else if (expected.type() == org.apache.iceberg.types.Types.VariantType.get()) {
+      return false;
     } else {
       GroupType groupType = field.asGroupType();
       LogicalTypeAnnotation logicalTypeAnnotation = groupType.getLogicalTypeAnnotation();
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/TypeToMessageType.java b/parquet/src/main/java/org/apache/iceberg/parquet/TypeToMessageType.java
index 54f11500489b..57e125ed8ff5 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/TypeToMessageType.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/TypeToMessageType.java
@@ -86,6 +86,9 @@ public Type field(NestedField field) {
     if (field.type().isPrimitiveType()) {
       return primitive(field.type().asPrimitiveType(), repetition, id, name);
+
+    } else if (field.type() instanceof org.apache.iceberg.types.Types.VariantType) {
+      return variant(repetition, id, name);
 
     } else {
       NestedType nested = field.type().asNestedType();
       if (nested.isStructType()) {
@@ -117,6 +120,14 @@ public GroupType map(MapType map, Type.Repetition repetition, int id, String name) {
         .named(AvroSchemaUtil.makeCompatibleName(name));
   }
 
+  public Type variant(Type.Repetition repetition, int id, String originalName) {
+    String name = AvroSchemaUtil.makeCompatibleName(originalName);
+    return Types.buildGroup(repetition).id(id)
+        .required(BINARY).named("metadata")
+        .required(BINARY).named("value")
+        .named(name);
+  }
+
   public Type primitive(
       PrimitiveType primitive, Type.Repetition repetition, int id, String originalName) {
     String name = AvroSchemaUtil.makeCompatibleName(originalName);
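For reference, the new variant() conversion produces an unshredded metadata/value pair; for an optional variant field with id 2 the resulting group in Parquet schema notation is:

    optional group var = 2 {
      required binary metadata;
      required binary value;
    }
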
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/TypeWithSchemaVisitor.java b/parquet/src/main/java/org/apache/iceberg/parquet/TypeWithSchemaVisitor.java
index e0c07d31755e..f327743ea92e 100644
--- a/parquet/src/main/java/org/apache/iceberg/parquet/TypeWithSchemaVisitor.java
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/TypeWithSchemaVisitor.java
@@ -22,10 +22,13 @@
 import java.util.List;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.types.Type.TypeID;
 import org.apache.iceberg.types.Types;
 import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.ListLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.MapLogicalTypeAnnotation;
 import org.apache.parquet.schema.MessageType;
-import org.apache.parquet.schema.OriginalType;
 import org.apache.parquet.schema.PrimitiveType;
 import org.apache.parquet.schema.Type;
 
@@ -54,94 +57,14 @@ public static <T> T visit(
     } else {
       // if not a primitive, the typeId must be a group
       GroupType group = type.asGroupType();
-      OriginalType annotation = group.getOriginalType();
-      if (annotation != null) {
-        switch (annotation) {
-          case LIST:
-            Preconditions.checkArgument(
-                group.getFieldCount() == 1,
-                "Invalid list: does not contain single repeated field: %s",
-                group);
-
-            Type repeatedElement = group.getFields().get(0);
-            Preconditions.checkArgument(
-                repeatedElement.isRepetition(Type.Repetition.REPEATED),
-                "Invalid list: inner group is not repeated");
-
-            Type listElement = ParquetSchemaUtil.determineListElementType(group);
-            Types.ListType list = null;
-            Types.NestedField element = null;
-            if (iType != null) {
-              list = iType.asListType();
-              element = list.fields().get(0);
-            }
-
-            if (listElement.isRepetition(Type.Repetition.REPEATED)) {
-              return visitTwoLevelList(list, element, group, listElement, visitor);
-            } else {
-              return visitThreeLevelList(list, element, group, listElement, visitor);
-            }
-
-          case MAP:
-            Preconditions.checkArgument(
-                !group.isRepetition(Type.Repetition.REPEATED),
-                "Invalid map: top-level group is repeated: %s",
-                group);
-            Preconditions.checkArgument(
-                group.getFieldCount() == 1,
-                "Invalid map: does not contain single repeated field: %s",
-                group);
-
-            GroupType repeatedKeyValue = group.getType(0).asGroupType();
-            Preconditions.checkArgument(
-                repeatedKeyValue.isRepetition(Type.Repetition.REPEATED),
-                "Invalid map: inner group is not repeated");
-            Preconditions.checkArgument(
-                repeatedKeyValue.getFieldCount() <= 2,
-                "Invalid map: repeated group does not have 2 fields");
-
-            Types.MapType map = null;
-            Types.NestedField keyField = null;
-            Types.NestedField valueField = null;
-            if (iType != null) {
-              map = iType.asMapType();
-              keyField = map.fields().get(0);
-              valueField = map.fields().get(1);
-            }
-
-            visitor.fieldNames.push(repeatedKeyValue.getName());
-            try {
-              T keyResult = null;
-              T valueResult = null;
-              switch (repeatedKeyValue.getFieldCount()) {
-                case 2:
-                  // if there are 2 fields, both key and value are projected
-                  keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor);
-                  valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor);
-                  break;
-                case 1:
-                  // if there is just one, use the name to determine what it is
-                  Type keyOrValue = repeatedKeyValue.getType(0);
-                  if (keyOrValue.getName().equalsIgnoreCase("key")) {
-                    keyResult = visitField(keyField, keyOrValue, visitor);
-                    // value result remains null
-                  } else {
-                    valueResult = visitField(valueField, keyOrValue, visitor);
-                    // key result remains null
-                  }
-                  break;
-                default:
-                  // both results will remain null
-              }
-
-              return visitor.map(map, group, keyResult, valueResult);
-
-            } finally {
-              visitor.fieldNames.pop();
-            }
-
-          default:
-        }
+      LogicalTypeAnnotation annotation = group.getLogicalTypeAnnotation();
+      if (annotation instanceof ListLogicalTypeAnnotation) {
+        return visitList(iType, group, visitor);
+      } else if (annotation instanceof MapLogicalTypeAnnotation) {
+        return visitMap(iType, group, visitor);
+      } else if (iType != null && iType.typeId() == TypeID.VARIANT) {
+        // when Parquet has a VARIANT logical type, use it here
+        return visitVariant((Types.VariantType) iType, group, visitor);
       }
 
       Types.StructType struct = iType != null ? iType.asStructType() : null;
@@ -149,6 +72,93 @@ public static <T> T visit(
     }
   }
 
+  private static <T> T visitList(
+      org.apache.iceberg.types.Type iType, GroupType group, TypeWithSchemaVisitor<T> visitor) {
+    Preconditions.checkArgument(
+        group.getFieldCount() == 1,
+        "Invalid list: does not contain single repeated field: %s",
+        group);
+
+    Type repeatedElement = group.getFields().get(0);
+    Preconditions.checkArgument(
+        repeatedElement.isRepetition(Type.Repetition.REPEATED),
+        "Invalid list: inner group is not repeated");
+
+    Type listElement = ParquetSchemaUtil.determineListElementType(group);
+    Types.ListType list = null;
+    Types.NestedField element = null;
+    if (iType != null) {
+      list = iType.asListType();
+      element = list.fields().get(0);
+    }
+
+    if (listElement.isRepetition(Type.Repetition.REPEATED)) {
+      return visitTwoLevelList(list, element, group, listElement, visitor);
+    } else {
+      return visitThreeLevelList(list, element, group, listElement, visitor);
+    }
+  }
+
+  private static <T> T visitMap(
+      org.apache.iceberg.types.Type iType, GroupType group, TypeWithSchemaVisitor<T> visitor) {
+    Preconditions.checkArgument(
+        !group.isRepetition(Type.Repetition.REPEATED),
+        "Invalid map: top-level group is repeated: %s",
+        group);
+    Preconditions.checkArgument(
+        group.getFieldCount() == 1,
+        "Invalid map: does not contain single repeated field: %s",
+        group);
+
+    GroupType repeatedKeyValue = group.getType(0).asGroupType();
+    Preconditions.checkArgument(
+        repeatedKeyValue.isRepetition(Type.Repetition.REPEATED),
+        "Invalid map: inner group is not repeated");
+    Preconditions.checkArgument(
+        repeatedKeyValue.getFieldCount() <= 2,
+        "Invalid map: repeated group does not have 2 fields");
+
+    Types.MapType map = null;
+    Types.NestedField keyField = null;
+    Types.NestedField valueField = null;
+    if (iType != null) {
+      map = iType.asMapType();
+      keyField = map.fields().get(0);
+      valueField = map.fields().get(1);
+    }
+
+    visitor.fieldNames.push(repeatedKeyValue.getName());
+    try {
+      T keyResult = null;
+      T valueResult = null;
+      switch (repeatedKeyValue.getFieldCount()) {
+        case 2:
+          // if there are 2 fields, both key and value are projected
+          keyResult = visitField(keyField, repeatedKeyValue.getType(0), visitor);
+          valueResult = visitField(valueField, repeatedKeyValue.getType(1), visitor);
+          break;
+        case 1:
+          // if there is just one, use the name to determine what it is
+          Type keyOrValue = repeatedKeyValue.getType(0);
+          if (keyOrValue.getName().equalsIgnoreCase("key")) {
+            keyResult = visitField(keyField, keyOrValue, visitor);
+            // value result remains null
+          } else {
+            valueResult = visitField(valueField, keyOrValue, visitor);
+            // key result remains null
+          }
+          break;
+        default:
+          // both results will remain null
+      }
+
+      return visitor.map(map, group, keyResult, valueResult);
+
+    } finally {
+      visitor.fieldNames.pop();
+    }
+  }
+
   private static <T> T visitTwoLevelList(
       Types.ListType iListType,
       Types.NestedField iListElement,
@@ -201,6 +211,17 @@ private static <T> List<T> visitFields(
     return results;
   }
 
+  private static <T> T visitVariant(
+      Types.VariantType variant, GroupType group, TypeWithSchemaVisitor<T> visitor) {
+    ParquetVariantVisitor<T> variantVisitor = visitor.variantVisitor();
+    if (variantVisitor != null) {
+      T variantResult = ParquetVariantVisitor.visit(group, variantVisitor);
+      return visitor.variant(variant, variantResult);
+    } else {
+      return visitor.variant(variant, null);
+    }
+  }
+
   public T message(Types.StructType iStruct, MessageType message, List<T> fields) {
     return null;
   }
@@ -217,11 +238,19 @@ public T map(Types.MapType iMap, GroupType map, T key, T value) {
     return null;
   }
 
+  public T variant(Types.VariantType iVariant, T result) {
+    throw new UnsupportedOperationException("Not implemented for variant");
+  }
+
   public T primitive(
       org.apache.iceberg.types.Type.PrimitiveType iPrimitive, PrimitiveType primitive) {
     return null;
   }
 
+  public ParquetVariantVisitor<T> variantVisitor() {
+    return null;
+  }
+
   protected String[] currentPath() {
     return Lists.newArrayList(fieldNames.descendingIterator()).toArray(new String[0]);
   }
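A visitor opts in to variant traversal by overriding both hooks; the default variant() throws. A minimal sketch mirroring the BaseParquetReaders wiring above (MyReadBuilder is hypothetical):

    class MyReadBuilder extends TypeWithSchemaVisitor<ParquetValueReader<?>> {
      private final MessageType type; // the Parquet file schema

      MyReadBuilder(MessageType type) {
        this.type = type;
      }

      @Override
      public ParquetValueReader<?> variant(
          Types.VariantType iVariant, ParquetValueReader<?> reader) {
        return reader; // result built by the variant visitor below
      }

      @Override
      public ParquetVariantVisitor<ParquetValueReader<?>> variantVisitor() {
        return new VariantReaderBuilder(type, Arrays.asList(currentPath()));
      }
    }
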
diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VariantReaderBuilder.java b/parquet/src/main/java/org/apache/iceberg/parquet/VariantReaderBuilder.java
new file mode 100644
index 000000000000..c308dde0d116
--- /dev/null
+++ b/parquet/src/main/java/org/apache/iceberg/parquet/VariantReaderBuilder.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.parquet;
+
+import java.util.Deque;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.parquet.ParquetVariantReaders.VariantValueReader;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Streams;
+import org.apache.iceberg.variants.Variants;
+import org.apache.parquet.column.ColumnDescriptor;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.LogicalTypeAnnotationVisitor;
+import org.apache.parquet.schema.LogicalTypeAnnotation.StringLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.PrimitiveType;
+import org.apache.parquet.schema.Type;
+
+public class VariantReaderBuilder extends ParquetVariantVisitor<ParquetValueReader<?>> {
+  private final MessageType schema;
+  private final Iterable<String> basePath;
+  private final Deque<String> fieldNames = Lists.newLinkedList();
+
+  public VariantReaderBuilder(MessageType schema, Iterable<String> basePath) {
+    this.schema = schema;
+    this.basePath = basePath;
+  }
+
+  @Override
+  public void beforeField(Type type) {
+    fieldNames.addLast(type.getName());
+  }
+
+  @Override
+  public void afterField(Type type) {
+    fieldNames.removeLast();
+  }
+
+  private String[] currentPath() {
+    return Streams.concat(Streams.stream(basePath), fieldNames.stream()).toArray(String[]::new);
+  }
+
+  private String[] path(String name) {
+    return Streams.concat(Streams.stream(basePath), fieldNames.stream(), Stream.of(name))
+        .toArray(String[]::new);
+  }
+
+  @Override
+  public ParquetValueReader<?> variant(
+      GroupType variant, ParquetValueReader<?> metadataReader, ParquetValueReader<?> valueReader) {
+    return ParquetVariantReaders.variant(metadataReader, valueReader);
+  }
+
+  @Override
+  public ParquetValueReader<?> metadata(PrimitiveType metadata) {
+    ColumnDescriptor desc = schema.getColumnDescription(currentPath());
+    return ParquetVariantReaders.metadata(desc);
+  }
+
+  @Override
+  public VariantValueReader serialized(PrimitiveType value) {
+    ColumnDescriptor desc = schema.getColumnDescription(currentPath());
+    return ParquetVariantReaders.serialized(desc);
+  }
+
+  @Override
+  public VariantValueReader primitive(PrimitiveType primitive) {
+    ColumnDescriptor desc = schema.getColumnDescription(currentPath());
+
+    if (primitive.getLogicalTypeAnnotation() != null) {
+      Optional<VariantValueReader> reader =
+          primitive.getLogicalTypeAnnotation().accept(new LogicalTypeToVariantReader(desc));
+      if (reader.isPresent()) {
+        return reader.get();
+      }
+
+    } else {
+      switch (primitive.getPrimitiveTypeName()) {
+        case BINARY:
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.BINARY, ParquetValueReaders.byteBuffers(desc));
+        case BOOLEAN:
+          // the actual boolean type will be fixed in PrimitiveWrapper
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.BOOLEAN_TRUE, ParquetValueReaders.unboxed(desc));
+        case INT32:
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.INT32, ParquetValueReaders.unboxed(desc));
+        case INT64:
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.INT64, ParquetValueReaders.unboxed(desc));
+        case FLOAT:
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.FLOAT, ParquetValueReaders.unboxed(desc));
+        case DOUBLE:
+          return ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.DOUBLE, ParquetValueReaders.unboxed(desc));
+      }
+    }
+
+    // note that both FIXED_LEN_BYTE_ARRAY and INT96 are not valid Variant primitives
+    throw new UnsupportedOperationException("Unsupported shredded value type: " + primitive);
+  }
+
+  @Override
+  public VariantValueReader value(
+      GroupType group, ParquetValueReader<?> valueReader, ParquetValueReader<?> typedReader) {
+    int valueDL =
+        valueReader != null
+            ? schema.getMaxDefinitionLevel(path(ParquetVariantVisitor.VALUE)) - 1
+            : Integer.MAX_VALUE;
+    int typedDL =
+        typedReader != null
+            ? schema.getMaxDefinitionLevel(path(ParquetVariantVisitor.TYPED_VALUE)) - 1
+            : Integer.MAX_VALUE;
+    return ParquetVariantReaders.shredded(valueDL, valueReader, typedDL, typedReader);
+  }
+
+  @Override
+  public VariantValueReader object(
+      GroupType group,
+      ParquetValueReader<?> valueReader,
+      List<ParquetValueReader<?>> fieldResults) {
+    // TODO: if fields are required, set DL to 0. Or maybe these values work?
+    int valueDL = schema.getMaxDefinitionLevel(path(ParquetVariantVisitor.VALUE)) - 1;
+    int fieldsDL = schema.getMaxDefinitionLevel(path(ParquetVariantVisitor.TYPED_VALUE)) - 1;
+
+    List<String> shreddedFieldNames =
+        group.getType(ParquetVariantVisitor.TYPED_VALUE).asGroupType().getFields().stream()
+            .map(Type::getName)
+            .collect(Collectors.toList());
+    List<VariantValueReader> fieldReaders =
+        fieldResults.stream().map(VariantValueReader.class::cast).collect(Collectors.toList());
+
+    return ParquetVariantReaders.objects(
+        valueDL, valueReader, fieldsDL, shreddedFieldNames, fieldReaders);
+  }
+
+  @Override
+  public VariantValueReader array(
+      GroupType array, ParquetValueReader<?> valueResult, ParquetValueReader<?> elementResult) {
+    throw new UnsupportedOperationException("Array is not yet supported");
+  }
+
+  private static class LogicalTypeToVariantReader
+      implements LogicalTypeAnnotationVisitor<VariantValueReader> {
+    private final ColumnDescriptor desc;
+
+    private LogicalTypeToVariantReader(ColumnDescriptor desc) {
+      this.desc = desc;
+    }
+
+    @Override
+    public Optional<VariantValueReader> visit(StringLogicalTypeAnnotation ignored) {
+      VariantValueReader reader =
+          ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.STRING, ParquetValueReaders.strings(desc));
+
+      return Optional.of(reader);
+    }
+
+    @Override
+    public Optional<VariantValueReader> visit(DecimalLogicalTypeAnnotation logical) {
+      Variants.PhysicalType variantType = variantDecimalType(desc.getPrimitiveType());
+      VariantValueReader reader =
+          ParquetVariantReaders.asVariant(variantType, ParquetValueReaders.bigDecimals(desc));
+
+      return Optional.of(reader);
+    }
+
+    @Override
+    public Optional<VariantValueReader> visit(DateLogicalTypeAnnotation ignored) {
+      VariantValueReader reader =
+          ParquetVariantReaders.asVariant(
+              Variants.PhysicalType.DATE, ParquetValueReaders.unboxed(desc));
+
+      return Optional.of(reader);
+    }
+
+    @Override
+    public Optional<VariantValueReader> visit(TimestampLogicalTypeAnnotation logical) {
+      Variants.PhysicalType variantType =
+          logical.isAdjustedToUTC()
+              ? Variants.PhysicalType.TIMESTAMPTZ
+              : Variants.PhysicalType.TIMESTAMPNTZ;
+
+      VariantValueReader reader =
+          ParquetVariantReaders.asVariant(variantType, ParquetValueReaders.timestamps(desc));
+
+      return Optional.of(reader);
+    }
+
+    @Override
+    public Optional<VariantValueReader> visit(IntLogicalTypeAnnotation logical) {
+      if (!logical.isSigned()) {
+        // unsigned ints are not allowed for shredded fields
+        throw new UnsupportedOperationException("Unsupported shredded value type: " + logical);
+      }
+
+      VariantValueReader reader;
+      switch (logical.getBitWidth()) {
+        case 64:
+          reader =
+              ParquetVariantReaders.asVariant(
+                  Variants.PhysicalType.INT64, ParquetValueReaders.unboxed(desc));
+          break;
+        case 32:
+          reader =
+              ParquetVariantReaders.asVariant(
+                  Variants.PhysicalType.INT32, ParquetValueReaders.unboxed(desc));
+          break;
+        case 16:
+          reader =
+              ParquetVariantReaders.asVariant(
+                  Variants.PhysicalType.INT16, ParquetValueReaders.intsAsShort(desc));
+          break;
+        case 8:
+          reader =
+              ParquetVariantReaders.asVariant(
+                  Variants.PhysicalType.INT8, ParquetValueReaders.intsAsByte(desc));
+          break;
+        default:
+          throw new IllegalArgumentException("Invalid bit width for int: " + logical.getBitWidth());
+      }
+
+      return Optional.of(reader);
+    }
+
+    private static Variants.PhysicalType variantDecimalType(PrimitiveType primitive) {
+      switch (primitive.getPrimitiveTypeName()) {
+        case FIXED_LEN_BYTE_ARRAY:
+        case BINARY:
+          return Variants.PhysicalType.DECIMAL16;
+        case INT64:
+          return Variants.PhysicalType.DECIMAL8;
+        case INT32:
+          return Variants.PhysicalType.DECIMAL4;
+      }
+
+      throw new IllegalArgumentException("Invalid primitive type for decimal: " + primitive);
+    }
+  }
+}
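VariantReaderBuilder is normally invoked from TypeWithSchemaVisitor, but it can also be driven directly against a variant column group. A sketch, assuming fileSchema is the Parquet MessageType and the variant column is named "var":

    GroupType variantGroup = fileSchema.getType("var").asGroupType();
    ParquetValueReader<?> reader =
        ParquetVariantVisitor.visit(
            variantGroup, new VariantReaderBuilder(fileSchema, ImmutableList.of("var")));
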
+ * + */ + +package org.apache.iceberg.parquet; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericRecord; +import org.apache.hadoop.conf.Configuration; +import org.apache.iceberg.Schema; +import org.apache.iceberg.data.Record; +import org.apache.iceberg.data.parquet.InternalReader; +import org.apache.iceberg.inmemory.InMemoryOutputFile; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.OutputFile; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Streams; +import org.apache.iceberg.types.Types.IntegerType; +import org.apache.iceberg.types.Types.NestedField; +import org.apache.iceberg.types.Types.VariantType; +import org.apache.iceberg.variants.ShreddedObject; +import org.apache.iceberg.variants.Variant; +import org.apache.iceberg.variants.VariantMetadata; +import org.apache.iceberg.variants.VariantObject; +import org.apache.iceberg.variants.VariantPrimitive; +import org.apache.iceberg.variants.VariantTestUtil; +import org.apache.iceberg.variants.VariantValue; +import org.apache.iceberg.variants.Variants; +import org.apache.parquet.Preconditions; +import org.apache.parquet.avro.AvroSchemaConverter; +import org.apache.parquet.avro.AvroWriteSupport; +import org.apache.parquet.hadoop.ParquetWriter; +import org.apache.parquet.hadoop.api.WriteSupport; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.LogicalTypeAnnotation; +import org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName; +import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.Types; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.FieldSource; +import org.junit.jupiter.params.provider.MethodSource; + +public class TestVariantReaders { + private static final Schema SCHEMA = + new Schema( + NestedField.required(1, "id", IntegerType.get()), + NestedField.required(2, "var", VariantType.get())); + + private static final LogicalTypeAnnotation STRING = LogicalTypeAnnotation.stringType(); + + private static final ByteBuffer TEST_METADATA_BUFFER = + VariantTestUtil.createMetadata(ImmutableList.of("a", "b", "c", "d", "e"), true); + private static final ByteBuffer TEST_OBJECT_BUFFER = + VariantTestUtil.createObject( + TEST_METADATA_BUFFER, + ImmutableMap.of( + "a", Variants.ofNull(), + "d", Variants.of(Variants.PhysicalType.STRING, "iceberg"))); + + private static final VariantMetadata EMPTY_METADATA = + Variants.metadata(VariantTestUtil.emptyMetadata()); + private static final VariantMetadata TEST_METADATA = Variants.metadata(TEST_METADATA_BUFFER); + private static final VariantObject TEST_OBJECT = + (VariantObject) Variants.value(TEST_METADATA, TEST_OBJECT_BUFFER); + + private 
static final VariantPrimitive<?>[] PRIMITIVES = + new VariantPrimitive<?>[] { + Variants.ofNull(), + Variants.of(true), + Variants.of(false), + Variants.of((byte) 34), + Variants.of((byte) -34), + Variants.of((short) 1234), + Variants.of((short) -1234), + Variants.of(12345), + Variants.of(-12345), + Variants.of(9876543210L), + Variants.of(-9876543210L), + Variants.of(10.11F), + Variants.of(-10.11F), + Variants.of(14.3D), + Variants.of(-14.3D), + Variants.ofIsoDate("2024-11-07"), + Variants.ofIsoDate("1957-11-07"), + Variants.ofIsoTimestamptz("2024-11-07T12:33:54.123456+00:00"), + Variants.ofIsoTimestamptz("1957-11-07T12:33:54.123456+00:00"), + Variants.ofIsoTimestampntz("2024-11-07T12:33:54.123456"), + Variants.ofIsoTimestampntz("1957-11-07T12:33:54.123456"), + Variants.of(new BigDecimal("123456.7890")), // decimal4 + Variants.of(new BigDecimal("-123456.7890")), // decimal4 + Variants.of(new BigDecimal("1234567890.987654321")), // decimal8 + Variants.of(new BigDecimal("-1234567890.987654321")), // decimal8 + Variants.of(new BigDecimal("9876543210.123456789")), // decimal16 + Variants.of(new BigDecimal("-9876543210.123456789")), // decimal16 + Variants.of(ByteBuffer.wrap(new byte[] {0x0a, 0x0b, 0x0c, 0x0d})), + Variants.of("iceberg"), + }; + + private static Stream<Arguments> metadataAndValues() { + Stream<Arguments> primitives = + Stream.of(PRIMITIVES).map(variant -> Arguments.of(EMPTY_METADATA, variant)); + Stream<Arguments> object = Stream.of(Arguments.of(TEST_METADATA, TEST_OBJECT)); + return Streams.concat(primitives, object); + } + + @ParameterizedTest + @MethodSource("metadataAndValues") + public void testUnshreddedVariants(VariantMetadata metadata, VariantValue expected) + throws IOException { + GroupType variantType = variant("var", 2); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", metadata.buffer(), "value", serialize(expected))); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(metadata, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @ParameterizedTest + @MethodSource("metadataAndValues") + public void testUnshreddedVariantsWithShreddedSchema( + VariantMetadata metadata, VariantValue expected) throws IOException { + // the variant's Parquet schema has a shredded field that is unused by all data values + GroupType variantType = variant("var", 2, shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", metadata.buffer(), "value", serialize(expected))); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(metadata, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @ParameterizedTest + @FieldSource("PRIMITIVES") + public void testShreddedVariantPrimitives(VariantPrimitive<?> primitive) throws IOException { +
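// NULL is not a shredded type: it is carried by the value column, never by typed_value +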
Assumptions.assumeThat(primitive.type() != Variants.PhysicalType.NULL) + .as("Null is not a shredded type") + .isTrue(); + + GroupType variantType = variant("var", 2, shreddedType(primitive)); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + VariantTestUtil.emptyMetadata(), + "typed_value", + toAvroValue(primitive))); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(EMPTY_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(primitive, actualVariant.value()); + } + + @Test + public void testNullValueAndNullTypedValue() throws IOException { + GroupType variantType = variant("var", 2, shreddedPrimitive(PrimitiveTypeName.INT32)); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", VariantTestUtil.emptyMetadata())); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(EMPTY_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(Variants.ofNull(), actualVariant.value()); + } + + @Test + public void testMissingValueColumn() throws IOException { + GroupType variantType = + Types.buildGroup(Type.Repetition.REQUIRED) + .id(2) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("var"); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", VariantTestUtil.emptyMetadata(), "typed_value", 34)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(EMPTY_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(Variants.of(34), actualVariant.value()); + } + + @Test + public void testValueAndTypedValueConflict() throws IOException { + GroupType variantType = variant("var", 2, shreddedPrimitive(PrimitiveTypeName.INT32)); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + VariantTestUtil.emptyMetadata(), + "value", + serialize(Variants.of("str")), + "typed_value", + 34)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(EMPTY_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(Variants.of(34), actualVariant.value()); + } + + @Test + public void testUnsignedInteger() { + GroupType variantType = + variant( + "var", + 
2, + shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(32, false))); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", VariantTestUtil.emptyMetadata())); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + assertThatThrownBy(() -> writeAndRead(parquetSchema, record)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage("Unsupported shredded value type: INTEGER(32,false)"); + } + + @Test + public void testFixedLengthByteArray() { + GroupType variantType = + variant( + "var", + 2, + Types.optional(PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY).length(4).named("typed_value")); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record(variantType, Map.of("metadata", VariantTestUtil.emptyMetadata())); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + assertThatThrownBy(() -> writeAndRead(parquetSchema, record)) + .isInstanceOf(UnsupportedOperationException.class) + .hasMessage( + "Unsupported shredded value type: optional fixed_len_byte_array(4) typed_value"); + } + + @Test + public void testShreddedObject() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.ofNull()))); + GenericRecord b = record(fieldB, Map.of("typed_value", "")); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.ofNull()); + expected.put("b", Variants.of("")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testShreddedObjectMissingField() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.of(false)))); + // value and typed_value are null, but a struct for b is required + GenericRecord b = record(fieldB, Map.of()); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + 
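// "b" has null value and typed_value, so the reconstructed object should contain only "a" +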
assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.of(false)); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testEmptyShreddedObject() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of()); // missing + GenericRecord b = record(fieldB, Map.of()); // missing + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testShreddedObjectMissingFieldValueColumn() throws IOException { + // field groups do not have value + GroupType fieldA = + Types.buildGroup(Type.Repetition.REQUIRED) + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.REQUIRED) + .addField(shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)) + .named("b"); + GroupType objectFields = + Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fieldA, fieldB).named("typed_value"); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of()); // typed_value=null + GenericRecord b = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("b", Variants.of("iceberg")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testShreddedObjectMissingTypedValue() throws IOException { + // field groups do not have typed_value + GroupType fieldA = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .named("b"); + GroupType objectFields = + 
Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fieldA, fieldB).named("typed_value"); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of()); // value=null + GenericRecord b = record(fieldB, Map.of("value", serialize(Variants.of("iceberg")))); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("b", Variants.of("iceberg")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testShreddedObjectWithinShreddedObject() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType innerFields = objectFields(fieldA, fieldB); + GroupType fieldC = field("c", innerFields); + GroupType fieldD = field("d", shreddedPrimitive(PrimitiveTypeName.DOUBLE)); + GroupType outerFields = objectFields(fieldC, fieldD); + GroupType variantType = variant("var", 2, outerFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of("typed_value", 34)); + GenericRecord b = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord inner = record(innerFields, Map.of("a", a, "b", b)); + GenericRecord c = record(fieldC, Map.of("typed_value", inner)); + GenericRecord d = record(fieldD, Map.of("typed_value", -0.0D)); + GenericRecord outer = record(outerFields, Map.of("c", c, "d", d)); + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", outer)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expectedInner = Variants.object(TEST_METADATA); + expectedInner.put("a", Variants.of(34)); + expectedInner.put("b", Variants.of("iceberg")); + ShreddedObject expectedOuter = Variants.object(TEST_METADATA); + expectedOuter.put("c", expectedInner); + expectedOuter.put("d", Variants.of(-0.0D)); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expectedOuter, actualVariant.value()); + } + + @Test + public void testShreddedObjectWithOptionalFieldStructs() throws IOException { + // fields use an incorrect OPTIONAL struct of value and typed_value to test definition levels + GroupType fieldA = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.INT32)) + .named("a"); + GroupType fieldB = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.BINARY, 
STRING)) + .named("b"); + GroupType fieldC = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.DOUBLE)) + .named("c"); + GroupType fieldD = + Types.buildGroup(Type.Repetition.OPTIONAL) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedPrimitive(PrimitiveTypeName.BOOLEAN)) + .named("d"); + GroupType objectFields = + Types.buildGroup(Type.Repetition.OPTIONAL) + .addFields(fieldA, fieldB, fieldC, fieldD) + .named("typed_value"); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.of(34)))); + GenericRecord b = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord c = record(fieldC, Map.of()); // c.value and c.typed_value are missing + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b, "c", c)); // d is missing + GenericRecord variant = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + // the expected value is the shredded field value + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.of(34)); + expected.put("b", Variants.of("iceberg")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testPartiallyShreddedObject() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + ShreddedObject baseObject = Variants.object(TEST_METADATA); + baseObject.put("d", Variants.ofIsoDate("2024-01-30")); + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.ofNull()))); + GenericRecord b = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + TEST_METADATA_BUFFER, + "value", + serialize(baseObject), + "typed_value", + fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.ofNull()); + expected.put("b", Variants.of("iceberg")); + expected.put("d", Variants.ofIsoDate("2024-01-30")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testPartiallyShreddedObjectFieldConflict() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", 
shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + ByteBuffer baseObjectBuffer = + VariantTestUtil.createObject( + TEST_METADATA_BUFFER, Map.of("b", Variants.ofIsoDate("2024-01-30"))); // conflict + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.ofNull()))); + GenericRecord b = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + TEST_METADATA_BUFFER, + "value", + baseObjectBuffer, + "typed_value", + fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + // the expected value is the shredded field value + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.ofNull()); + expected.put("b", Variants.of("iceberg")); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testPartiallyShreddedObjectMissingFieldConflict() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + ByteBuffer baseObjectBuffer = + VariantTestUtil.createObject( + TEST_METADATA_BUFFER, Map.of("b", Variants.ofIsoDate("2024-01-30"))); // conflict + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.ofNull()))); + // value and typed_value are null, but a struct for b is required + GenericRecord b = record(fieldB, Map.of()); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + TEST_METADATA_BUFFER, + "value", + baseObjectBuffer, + "typed_value", + fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + // the expected value is the shredded field value + ShreddedObject expected = Variants.object(TEST_METADATA); + expected.put("a", Variants.ofNull()); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(expected, actualVariant.value()); + } + + @Test + public void testNonObjectWithNullShreddedFields() throws IOException { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord variant = + record( + variantType, + Map.of("metadata", TEST_METADATA_BUFFER, 
"value", serialize(Variants.of(34)))); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + Record actual = writeAndRead(parquetSchema, record); + assertThat(actual.getField("id")).isEqualTo(1); + assertThat(actual.getField("var")).isInstanceOf(Variant.class); + + Variant actualVariant = (Variant) actual.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualVariant.metadata()); + VariantTestUtil.assertEqual(Variants.of(34), actualVariant.value()); + } + + @Test + public void testNonObjectWithNonNullShreddedFields() { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of("value", serialize(Variants.ofNull()))); + GenericRecord b = record(fieldB, Map.of("value", serialize(Variants.of(9876543210L)))); + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + TEST_METADATA_BUFFER, + "value", + serialize(Variants.of(34)), + "typed_value", + fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + assertThatThrownBy(() -> writeAndRead(parquetSchema, record)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid variant, non-object value with shredded fields"); + } + + @Test + public void testEmptyPartiallyShreddedObjectConflict() { + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType objectFields = objectFields(fieldA, fieldB); + GroupType variantType = variant("var", 2, objectFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord a = record(fieldA, Map.of()); // missing + GenericRecord b = record(fieldB, Map.of()); // missing + GenericRecord fields = record(objectFields, Map.of("a", a, "b", b)); + GenericRecord variant = + record( + variantType, + Map.of( + "metadata", + TEST_METADATA_BUFFER, + "value", + serialize(Variants.ofNull()), // conflicting non-object + "typed_value", + fields)); + GenericRecord record = record(parquetSchema, Map.of("id", 1, "var", variant)); + + assertThatThrownBy(() -> writeAndRead(parquetSchema, record)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid variant, non-object value with shredded fields"); + } + + @Test + public void testMixedRecords() throws IOException { + // tests multiple rows to check that Parquet columns are correctly advanced + GroupType fieldA = field("a", shreddedPrimitive(PrimitiveTypeName.INT32)); + GroupType fieldB = field("b", shreddedPrimitive(PrimitiveTypeName.BINARY, STRING)); + GroupType innerFields = objectFields(fieldA, fieldB); + GroupType fieldC = field("c", innerFields); + GroupType fieldD = field("d", shreddedPrimitive(PrimitiveTypeName.DOUBLE)); + GroupType outerFields = objectFields(fieldC, fieldD); + GroupType variantType = variant("var", 2, outerFields); + MessageType parquetSchema = parquetSchema(variantType); + + GenericRecord zero = record(parquetSchema, Map.of("id", 0)); + + GenericRecord a1 = record(fieldA, Map.of()); // missing + GenericRecord b1 = record(fieldB, Map.of("typed_value", "iceberg")); + GenericRecord inner1 = 
record(innerFields, Map.of("a", a1, "b", b1)); + GenericRecord c1 = record(fieldC, Map.of("typed_value", inner1)); + GenericRecord d1 = record(fieldD, Map.of()); // missing + GenericRecord outer1 = record(outerFields, Map.of("c", c1, "d", d1)); + GenericRecord variant1 = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", outer1)); + GenericRecord one = record(parquetSchema, Map.of("id", 1, "var", variant1)); + + ShreddedObject expectedC1 = Variants.object(TEST_METADATA); + expectedC1.put("b", Variants.of("iceberg")); + ShreddedObject expectedOne = Variants.object(TEST_METADATA); + expectedOne.put("c", expectedC1); + + GenericRecord c2 = record(fieldC, Map.of("value", serialize(Variants.of((byte) 8)))); + GenericRecord d2 = record(fieldD, Map.of("typed_value", -0.0D)); + GenericRecord outer2 = record(outerFields, Map.of("c", c2, "d", d2)); + GenericRecord variant2 = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", outer2)); + GenericRecord two = record(parquetSchema, Map.of("id", 2, "var", variant2)); + + ShreddedObject expectedTwo = Variants.object(TEST_METADATA); + expectedTwo.put("c", Variants.of((byte) 8)); + expectedTwo.put("d", Variants.of(-0.0D)); + + GenericRecord a3 = record(fieldA, Map.of("typed_value", 34)); + GenericRecord b3 = record(fieldB, Map.of("value", serialize(Variants.of("")))); + GenericRecord inner3 = record(innerFields, Map.of("a", a3, "b", b3)); + GenericRecord c3 = record(fieldC, Map.of("typed_value", inner3)); + GenericRecord d3 = record(fieldD, Map.of("typed_value", 0.0D)); + GenericRecord outer3 = record(outerFields, Map.of("c", c3, "d", d3)); + GenericRecord variant3 = + record(variantType, Map.of("metadata", TEST_METADATA_BUFFER, "typed_value", outer3)); + GenericRecord three = record(parquetSchema, Map.of("id", 3, "var", variant3)); + + ShreddedObject expectedC3 = Variants.object(TEST_METADATA); + expectedC3.put("a", Variants.of(34)); + expectedC3.put("b", Variants.of("")); + ShreddedObject expectedThree = Variants.object(TEST_METADATA); + expectedThree.put("c", expectedC3); + expectedThree.put("d", Variants.of(0.0D)); + + List records = writeAndRead(parquetSchema, List.of(zero, one, two, three)); + + Record actualZero = records.get(0); + assertThat(actualZero.getField("id")).isEqualTo(0); + assertThat(actualZero.getField("var")).isNull(); + + Record actualOne = records.get(1); + assertThat(actualOne.getField("id")).isEqualTo(1); + assertThat(actualOne.getField("var")).isInstanceOf(Variant.class); + + Variant actualOneVariant = (Variant) actualOne.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualOneVariant.metadata()); + VariantTestUtil.assertEqual(expectedOne, actualOneVariant.value()); + + Record actualTwo = records.get(2); + assertThat(actualTwo.getField("id")).isEqualTo(2); + assertThat(actualTwo.getField("var")).isInstanceOf(Variant.class); + + Variant actualTwoVariant = (Variant) actualTwo.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualTwoVariant.metadata()); + VariantTestUtil.assertEqual(expectedTwo, actualTwoVariant.value()); + + Record actualThree = records.get(3); + assertThat(actualThree.getField("id")).isEqualTo(3); + assertThat(actualThree.getField("var")).isInstanceOf(Variant.class); + + Variant actualThreeVariant = (Variant) actualThree.getField("var"); + VariantTestUtil.assertEqual(TEST_METADATA, actualThreeVariant.metadata()); + VariantTestUtil.assertEqual(expectedThree, actualThreeVariant.value()); + } + + private static ByteBuffer 
serialize(VariantValue value) { + ByteBuffer buffer = ByteBuffer.allocate(value.sizeInBytes()).order(ByteOrder.LITTLE_ENDIAN); + value.writeTo(buffer, 0); + return buffer; + } + + /** Creates an Avro record from a map of field name to value. */ + private static GenericRecord record(GroupType type, Map<String, Object> fields) { + GenericRecord record = new GenericData.Record(avroSchema(type)); + for (Map.Entry<String, Object> entry : fields.entrySet()) { + record.put(entry.getKey(), entry.getValue()); + } + return record; + } + + /** + * This is a custom Parquet writer builder that injects a specific Parquet schema and then uses + * the Avro object model. This ensures that the Parquet file's schema is exactly what was passed. + */ + private static class TestWriterBuilder + extends ParquetWriter.Builder<GenericRecord, TestWriterBuilder> { + private MessageType parquetSchema = null; + + protected TestWriterBuilder(OutputFile outputFile) { + super(ParquetIO.file(outputFile)); + } + + TestWriterBuilder withFileType(MessageType schema) { + this.parquetSchema = schema; + return self(); + } + + @Override + protected TestWriterBuilder self() { + return this; + } + + @Override + protected WriteSupport<GenericRecord> getWriteSupport(Configuration conf) { + return new AvroWriteSupport<>(parquetSchema, avroSchema(parquetSchema), GenericData.get()); + } + } + + static Record writeAndRead(MessageType parquetSchema, GenericRecord record) throws IOException { + return Iterables.getOnlyElement(writeAndRead(parquetSchema, List.of(record))); + } + + static List<Record> writeAndRead(MessageType parquetSchema, List<GenericRecord> records) + throws IOException { + OutputFile outputFile = new InMemoryOutputFile(); + + try (ParquetWriter<GenericRecord> writer = + new TestWriterBuilder(outputFile).withFileType(parquetSchema).build()) { + for (GenericRecord record : records) { + writer.write(record); + } + } + + try (CloseableIterable<Record> reader = + Parquet.read(outputFile.toInputFile()) + .project(SCHEMA) + .createReaderFunc(fileSchema -> InternalReader.create(SCHEMA, fileSchema)) + .build()) { + return Lists.newArrayList(reader); + } + } + + private static MessageType parquetSchema(Type variantType) { + return Types.buildMessage() + .required(PrimitiveTypeName.INT32) + .id(1) + .named("id") + .addField(variantType) + .named("table"); + } + + private static GroupType variant(String name, int fieldId) { + return Types.buildGroup(Type.Repetition.REQUIRED) + .id(fieldId) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .required(PrimitiveTypeName.BINARY) + .named("value") + .named(name); + } + + private static void checkShreddedType(Type shreddedType) { + Preconditions.checkArgument( + shreddedType.getName().equals("typed_value"), + "Invalid shredded type name: %s should be typed_value", + shreddedType.getName()); + Preconditions.checkArgument( + shreddedType.isRepetition(Type.Repetition.OPTIONAL), + "Invalid shredded type repetition: %s should be OPTIONAL", + shreddedType.getRepetition()); + } + + private static Type shreddedPrimitive(PrimitiveTypeName primitive) { + return Types.optional(primitive).named("typed_value"); + } + + private static Type shreddedPrimitive( + PrimitiveTypeName primitive, LogicalTypeAnnotation annotation) { + return Types.optional(primitive).as(annotation).named("typed_value"); + } + + private static Type shreddedType(VariantValue value) { + switch (value.type()) { + case BOOLEAN_TRUE: + case BOOLEAN_FALSE: + return shreddedPrimitive(PrimitiveTypeName.BOOLEAN); + case INT8: + return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(8, true)); + case INT16: + return
shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.intType(16, true)); + case INT32: + return shreddedPrimitive(PrimitiveTypeName.INT32); + case INT64: + return shreddedPrimitive(PrimitiveTypeName.INT64); + case FLOAT: + return shreddedPrimitive(PrimitiveTypeName.FLOAT); + case DOUBLE: + return shreddedPrimitive(PrimitiveTypeName.DOUBLE); + case DECIMAL4: + BigDecimal decimal4 = (BigDecimal) value.asPrimitive().get(); + return shreddedPrimitive( + PrimitiveTypeName.INT32, LogicalTypeAnnotation.decimalType(decimal4.scale(), 9)); + case DECIMAL8: + BigDecimal decimal8 = (BigDecimal) value.asPrimitive().get(); + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.decimalType(decimal8.scale(), 18)); + case DECIMAL16: + BigDecimal decimal16 = (BigDecimal) value.asPrimitive().get(); + return shreddedPrimitive( + PrimitiveTypeName.BINARY, LogicalTypeAnnotation.decimalType(decimal16.scale(), 38)); + case DATE: + return shreddedPrimitive(PrimitiveTypeName.INT32, LogicalTypeAnnotation.dateType()); + case TIMESTAMPTZ: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS)); + case TIMESTAMPNTZ: + return shreddedPrimitive( + PrimitiveTypeName.INT64, LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS)); + case BINARY: + return shreddedPrimitive(PrimitiveTypeName.BINARY); + case STRING: + return shreddedPrimitive(PrimitiveTypeName.BINARY, STRING); + } + + throw new UnsupportedOperationException("Unsupported shredding type: " + value.type()); + } + + private static Object toAvroValue(VariantPrimitive<?> variant) { + switch (variant.type()) { + case DECIMAL4: + return ((BigDecimal) variant.get()).unscaledValue().intValueExact(); + case DECIMAL8: + return ((BigDecimal) variant.get()).unscaledValue().longValueExact(); + case DECIMAL16: + return ((BigDecimal) variant.get()).unscaledValue().toByteArray(); + default: + return variant.get(); + } + } + + private static GroupType variant(String name, int fieldId, Type shreddedType) { + checkShreddedType(shreddedType); + return Types.buildGroup(Type.Repetition.OPTIONAL) + .id(fieldId) + .required(PrimitiveTypeName.BINARY) + .named("metadata") + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static void checkField(GroupType fieldType) { + Preconditions.checkArgument( + fieldType.isRepetition(Type.Repetition.REQUIRED), + "Invalid field type repetition: %s should be REQUIRED", + fieldType.getRepetition()); + } + + private static GroupType objectFields(GroupType...
fields) { + for (GroupType fieldType : fields) { + checkField(fieldType); + } + + return Types.buildGroup(Type.Repetition.OPTIONAL).addFields(fields).named("typed_value"); + } + + private static GroupType field(String name, Type shreddedType) { + checkShreddedType(shreddedType); + return Types.buildGroup(Type.Repetition.REQUIRED) + .optional(PrimitiveTypeName.BINARY) + .named("value") + .addField(shreddedType) + .named(name); + } + + private static org.apache.avro.Schema avroSchema(GroupType schema) { + if (schema instanceof MessageType) { + return new AvroSchemaConverter().convert((MessageType) schema); + + } else { + MessageType wrapped = Types.buildMessage().addField(schema).named("table"); + org.apache.avro.Schema avro = + new AvroSchemaConverter().convert(wrapped).getFields().get(0).schema(); + switch (avro.getType()) { + case RECORD: + return avro; + case UNION: + return avro.getTypes().get(1); + } + + throw new IllegalArgumentException("Invalid converted type: " + avro); + } + } +}