diff --git a/.github/workflows/spark_sql_test_ansi.yml b/.github/workflows/spark_sql_test_ansi.yml index 337e59efe..34a393115 100644 --- a/.github/workflows/spark_sql_test_ansi.yml +++ b/.github/workflows/spark_sql_test_ansi.yml @@ -22,17 +22,15 @@ concurrency: cancel-in-progress: true on: - # enable the following once Ansi support is completed - # push: - # paths-ignore: - # - "doc/**" - # - "**.md" - # pull_request: - # paths-ignore: - # - "doc/**" - # - "**.md" - - # manual trigger ONLY + push: + paths-ignore: + - "docs/**" + - "**.md" + pull_request: + paths-ignore: + - "docs/**" + - "**.md" + # manual trigger # https://docs.github.com/en/actions/managing-workflow-runs/manually-running-a-workflow workflow_dispatch: @@ -44,8 +42,8 @@ jobs: strategy: matrix: os: [ubuntu-latest] - java-version: [11] - spark-version: [{short: '3.4', full: '3.4.2'}] + java-version: [17] + spark-version: [{short: '4.0', full: '4.0.0-preview1'}] module: - {name: "catalyst", args1: "catalyst/test", args2: ""} - {name: "sql/core-1", args1: "", args2: sql/testOnly * -- -l org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest} @@ -75,7 +73,8 @@ jobs: - name: Run Spark tests run: | cd apache-spark - ENABLE_COMET=true ENABLE_COMET_ANSI_MODE=true build/sbt ${{ matrix.module.args1 }} "${{ matrix.module.args2 }}" + rm -rf /root/.m2/repository/org/apache/parquet # somehow parquet cache requires cleanups + RUST_BACKTRACE=1 ENABLE_COMET=true ENABLE_COMET_ANSI_MODE=true build/sbt ${{ matrix.module.args1 }} "${{ matrix.module.args2 }}" env: LC_ALL: "C.UTF-8" diff --git a/common/src/main/java/org/apache/comet/parquet/BatchReader.java b/common/src/main/java/org/apache/comet/parquet/BatchReader.java index bf8e6e550..4b63f84ef 100644 --- a/common/src/main/java/org/apache/comet/parquet/BatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/BatchReader.java @@ -285,6 +285,7 @@ public void init() throws URISyntaxException, IOException { missingColumns = new boolean[columns.size()]; List paths = requestedSchema.getPaths(); StructField[] nonPartitionFields = sparkSchema.fields(); + ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema); for (int i = 0; i < requestedSchema.getFieldCount(); i++) { Type t = requestedSchema.getFields().get(i); Preconditions.checkState( diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java index 0887ae12f..b40e27e73 100644 --- a/common/src/main/java/org/apache/comet/parquet/Native.java +++ b/common/src/main/java/org/apache/comet/parquet/Native.java @@ -75,6 +75,7 @@ public static native long initColumnReader( int precision, int expectedPrecision, int scale, + int expectedScale, int tu, boolean isAdjustedUtc, int batchSize, diff --git a/common/src/main/java/org/apache/comet/parquet/TypeUtil.java b/common/src/main/java/org/apache/comet/parquet/TypeUtil.java index b8b7ff525..bfbb7d0d2 100644 --- a/common/src/main/java/org/apache/comet/parquet/TypeUtil.java +++ b/common/src/main/java/org/apache/comet/parquet/TypeUtil.java @@ -27,6 +27,7 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; import org.apache.parquet.schema.Types; +import org.apache.spark.package$; import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException; import org.apache.spark.sql.types.*; @@ -169,6 +170,7 @@ && isUnsignedIntTypeMatched(logicalTypeAnnotation, 64)) { break; case INT96: if (sparkType == TimestampNTZType$.MODULE$) { + if (isSpark40Plus()) return; // Spark 4.0+ supports Timestamp NTZ with INT96 convertErrorForTimestampNTZ(typeName.name()); } else if (sparkType == DataTypes.TimestampType) { return; @@ -218,7 +220,8 @@ private static void validateTimestampType( // Throw an exception if the Parquet type is TimestampLTZ and the Catalyst type is TimestampNTZ. // This is to avoid mistakes in reading the timestamp values. if (((TimestampLogicalTypeAnnotation) logicalTypeAnnotation).isAdjustedToUTC() - && sparkType == TimestampNTZType$.MODULE$) { + && sparkType == TimestampNTZType$.MODULE$ + && !isSpark40Plus()) { convertErrorForTimestampNTZ("int64 time(" + logicalTypeAnnotation + ")"); } } @@ -232,12 +235,14 @@ private static void convertErrorForTimestampNTZ(String parquetType) { } private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) { - if (!DecimalType.is32BitDecimalType(dt)) return false; + if (!DecimalType.is32BitDecimalType(dt) && !(isSpark40Plus() && dt instanceof DecimalType)) + return false; return isDecimalTypeMatched(descriptor, dt); } private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) { - if (!DecimalType.is64BitDecimalType(dt)) return false; + if (!DecimalType.is64BitDecimalType(dt) && !(isSpark40Plus() && dt instanceof DecimalType)) + return false; return isDecimalTypeMatched(descriptor, dt); } @@ -261,7 +266,9 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation; // It's OK if the required decimal precision is larger than or equal to the physical decimal // precision in the Parquet metadata, as long as the decimal scale is the same. - return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale(); + return decimalType.getPrecision() <= d.precision() + && (decimalType.getScale() == d.scale() + || (isSpark40Plus() && decimalType.getScale() <= d.scale())); } return false; } @@ -278,4 +285,8 @@ private static boolean isUnsignedIntTypeMatched( && !((IntLogicalTypeAnnotation) logicalTypeAnnotation).isSigned() && ((IntLogicalTypeAnnotation) logicalTypeAnnotation).getBitWidth() == bitWidth; } + + private static boolean isSpark40Plus() { + return package$.MODULE$.SPARK_VERSION().compareTo("4.0") >= 0; + } } diff --git a/common/src/main/java/org/apache/comet/parquet/Utils.java b/common/src/main/java/org/apache/comet/parquet/Utils.java index 99f3a4edd..2d4b83a67 100644 --- a/common/src/main/java/org/apache/comet/parquet/Utils.java +++ b/common/src/main/java/org/apache/comet/parquet/Utils.java @@ -115,7 +115,7 @@ public static long initColumnReader( promotionInfo = new TypePromotionInfo(readType); } else { // If type promotion is not enable, we'll just use the Parquet primitive type and precision. - promotionInfo = new TypePromotionInfo(primitiveTypeId, precision); + promotionInfo = new TypePromotionInfo(primitiveTypeId, precision, scale); } return Native.initColumnReader( @@ -131,6 +131,7 @@ public static long initColumnReader( precision, promotionInfo.precision, scale, + promotionInfo.scale, tu, isAdjustedUtc, batchSize, @@ -144,10 +145,13 @@ static class TypePromotionInfo { int physicalTypeId; // Decimal precision from the Spark read schema, or -1 if it's not decimal type. int precision; + // Decimal scale from the Spark read schema, or -1 if it's not decimal type. + int scale; - TypePromotionInfo(int physicalTypeId, int precision) { + TypePromotionInfo(int physicalTypeId, int precision, int scale) { this.physicalTypeId = physicalTypeId; this.precision = precision; + this.scale = scale; } TypePromotionInfo(DataType sparkReadType) { @@ -159,13 +163,16 @@ static class TypePromotionInfo { int physicalTypeId = getPhysicalTypeId(primitiveType.getPrimitiveTypeName()); LogicalTypeAnnotation annotation = primitiveType.getLogicalTypeAnnotation(); int precision = -1; + int scale = -1; if (annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) { LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalAnnotation = (LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) annotation; precision = decimalAnnotation.getPrecision(); + scale = decimalAnnotation.getScale(); } this.physicalTypeId = physicalTypeId; this.precision = precision; + this.scale = scale; } } diff --git a/common/src/main/spark-3.x/org/apache/comet/shims/ShimFileFormat.scala b/common/src/main/spark-3.x/org/apache/comet/shims/ShimFileFormat.scala index 685e8f566..c34c947b5 100644 --- a/common/src/main/spark-3.x/org/apache/comet/shims/ShimFileFormat.scala +++ b/common/src/main/spark-3.x/org/apache/comet/shims/ShimFileFormat.scala @@ -19,6 +19,8 @@ package org.apache.comet.shims +import org.apache.spark.sql.types.{LongType, StructField, StructType} + object ShimFileFormat { // TODO: remove after dropping Spark 3.3 support and directly use FileFormat.ROW_INDEX @@ -29,4 +31,20 @@ object ShimFileFormat { // TODO: remove after dropping Spark 3.3 support and directly use // FileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME val ROW_INDEX_TEMPORARY_COLUMN_NAME: String = s"_tmp_metadata_$ROW_INDEX" + + // TODO: remove after dropping Spark 3.3 support and directly use + // RowIndexUtil.findRowIndexColumnIndexInSchema + def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = { + sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) => + field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME + } match { + case Some((field: StructField, idx: Int)) => + if (field.dataType != LongType) { + throw new RuntimeException( + s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType") + } + idx + case _ => -1 + } + } } diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala index 2f386869a..1702db135 100644 --- a/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala +++ b/common/src/main/spark-4.0/org/apache/comet/shims/ShimFileFormat.scala @@ -19,13 +19,15 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.parquet.ParquetRowIndexUtil +import org.apache.spark.sql.types.StructType object ShimFileFormat { // A name for a temporary column that holds row indexes computed by the file format reader // until they can be placed in the _metadata struct. val ROW_INDEX_TEMPORARY_COLUMN_NAME = ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME - val OPTION_RETURNING_BATCH = FileFormat.OPTION_RETURNING_BATCH + def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = + ParquetRowIndexUtil.findRowIndexColumnIndexInSchema(sparkSchema) } diff --git a/core/benches/parquet_read.rs b/core/benches/parquet_read.rs index 612d081c7..32463c077 100644 --- a/core/benches/parquet_read.rs +++ b/core/benches/parquet_read.rs @@ -54,7 +54,7 @@ fn bench(c: &mut Criterion) { ); b.iter(|| { let cd = ColumnDescriptor::new(t.clone(), 0, 0, ColumnPath::from(Vec::new())); - let promition_info = TypePromotionInfo::new(PhysicalType::INT32, -1); + let promition_info = TypePromotionInfo::new(PhysicalType::INT32, -1, -1); let mut column_reader = TestColumnReader::new( cd, promition_info, diff --git a/core/src/errors.rs b/core/src/errors.rs index b38c5e90b..b6f4d0889 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -21,6 +21,7 @@ use arrow::error::ArrowError; use datafusion_common::DataFusionError; use jni::errors::{Exception, ToException}; use regex::Regex; + use std::{ any::Any, convert, @@ -37,6 +38,7 @@ use std::{ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; use crate::execution::operators::ExecutionError; +use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; use parquet::errors::ParquetError; @@ -160,7 +162,11 @@ pub enum CometError { }, #[error("{class}: {msg}")] - JavaException { class: String, msg: String }, + JavaException { + class: String, + msg: String, + throwable: GlobalRef, + }, } pub fn init() { @@ -208,6 +214,15 @@ impl From for ExecutionError { fn from(value: CometError) -> Self { match value { CometError::Execution { source } => source, + CometError::JavaException { + class, + msg, + throwable, + } => ExecutionError::JavaException { + class, + msg, + throwable, + }, _ => ExecutionError::GeneralError(value.to_string()), } } @@ -379,17 +394,34 @@ pub fn unwrap_or_throw_default( } } -fn throw_exception(env: &mut JNIEnv, error: &E, backtrace: Option) { +fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option) { // If there isn't already an exception? if env.exception_check().is_ok() { // ... then throw new exception - let exception = error.to_exception(); - match backtrace { - Some(backtrace_string) => env.throw_new( - exception.class, - to_stacktrace_string(exception.msg, backtrace_string).unwrap(), - ), - _ => env.throw_new(exception.class, exception.msg), + match error { + CometError::JavaException { + class: _, + msg: _, + throwable, + } => env.throw(<&JThrowable>::from(throwable.as_obj())), + CometError::Execution { + source: + ExecutionError::JavaException { + class: _, + msg: _, + throwable, + }, + } => env.throw(<&JThrowable>::from(throwable.as_obj())), + _ => { + let exception = error.to_exception(); + match backtrace { + Some(backtrace_string) => env.throw_new( + exception.class, + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + _ => env.throw_new(exception.class, exception.msg), + } + } } .expect("Thrown exception") } diff --git a/core/src/execution/operators/mod.rs b/core/src/execution/operators/mod.rs index 13a0d9627..d0cc7ac68 100644 --- a/core/src/execution/operators/mod.rs +++ b/core/src/execution/operators/mod.rs @@ -25,6 +25,7 @@ use arrow::{ use arrow::compute::{cast_with_options, CastOptions}; use arrow_schema::ArrowError; +use jni::objects::GlobalRef; use std::{fmt::Debug, sync::Arc}; mod scan; @@ -52,6 +53,13 @@ pub enum ExecutionError { /// DataFusion error #[error("Error from DataFusion: {0}.")] DataFusionError(String), + + #[error("{class}: {msg}")] + JavaException { + class: String, + msg: String, + throwable: GlobalRef, + }, } /// Copy an Arrow Array diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 41376f03b..3f61c0324 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -385,5 +385,6 @@ pub(crate) fn convert_exception( Ok(CometError::JavaException { class: exception_class_name_str, msg: message_str, + throwable: env.new_global_ref(throwable)?, }) } diff --git a/core/src/parquet/mod.rs b/core/src/parquet/mod.rs index 4f87d15de..e6acaa26b 100644 --- a/core/src/parquet/mod.rs +++ b/core/src/parquet/mod.rs @@ -72,6 +72,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( precision: jint, read_precision: jint, scale: jint, + read_scale: jint, time_unit: jint, is_adjusted_utc: jboolean, batch_size: jint, @@ -94,7 +95,8 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( is_adjusted_utc, jni_path, )?; - let promotion_info = TypePromotionInfo::new_from_jni(read_primitive_type, read_precision); + let promotion_info = + TypePromotionInfo::new_from_jni(read_primitive_type, read_precision, read_scale); let ctx = Context { column_reader: ColumnReader::get( desc, diff --git a/core/src/parquet/read/column.rs b/core/src/parquet/read/column.rs index 6fc73f93f..22bade6b3 100644 --- a/core/src/parquet/read/column.rs +++ b/core/src/parquet/read/column.rs @@ -132,11 +132,17 @@ impl ColumnReader { (32, false) => typed_reader!(UInt32ColumnReader, Int64), _ => unimplemented!("Unsupported INT32 annotation: {:?}", lt), }, - LogicalType::Decimal { scale, precision } => { - if use_decimal_128 { + LogicalType::Decimal { + scale, + precision: _, + } => { + if use_decimal_128 || scale < &promotion_info.scale { typed_reader!( Int32DecimalColumnReader, - ArrowDataType::Decimal128(*precision as u8, *scale as i8) + ArrowDataType::Decimal128( + promotion_info.precision as u8, + promotion_info.scale as i8 + ) ) } else { typed_reader!(Int32ColumnReader, Int32) @@ -168,11 +174,17 @@ impl ColumnReader { ), _ => panic!("Unsupported INT64 annotation: {:?}", lt), }, - LogicalType::Decimal { scale, precision } => { - if use_decimal_128 { + LogicalType::Decimal { + scale, + precision: _, + } => { + if use_decimal_128 || scale < &promotion_info.scale { typed_reader!( Int64DecimalColumnReader, - ArrowDataType::Decimal128(*precision as u8, *scale as i8) + ArrowDataType::Decimal128( + promotion_info.precision as u8, + promotion_info.scale as i8 + ) ) } else { typed_reader!(Int64ColumnReader, Int64) @@ -248,7 +260,10 @@ impl ColumnReader { PhysicalType::FIXED_LEN_BYTE_ARRAY => { if let Some(logical_type) = desc.logical_type() { match logical_type { - LogicalType::Decimal { precision, scale } => { + LogicalType::Decimal { + precision, + scale: _, + } => { if !use_decimal_128 && precision <= DECIMAL_MAX_INT_DIGITS { typed_reader!(FLBADecimal32ColumnReader, Int32) } else if !use_decimal_128 && precision <= DECIMAL_MAX_LONG_DIGITS { @@ -256,7 +271,10 @@ impl ColumnReader { } else { typed_reader!( FLBADecimalColumnReader, - ArrowDataType::Decimal128(precision as u8, scale as i8) + ArrowDataType::Decimal128( + promotion_info.precision as u8, + promotion_info.scale as i8 + ) ) } } diff --git a/core/src/parquet/read/values.rs b/core/src/parquet/read/values.rs index 7f1195fa9..ebed5f95b 100644 --- a/core/src/parquet/read/values.rs +++ b/core/src/parquet/read/values.rs @@ -28,6 +28,7 @@ use crate::{ parquet::{data_type::*, read::DECIMAL_BYTE_WIDTH, ParquetMutableVector}, unlikely, }; +use arrow::datatypes::DataType as ArrowDataType; pub fn get_decoder( value_data: Buffer, @@ -651,6 +652,12 @@ macro_rules! make_plain_decimal_impl { debug_assert!(byte_width <= DECIMAL_BYTE_WIDTH); + let src_scale = src.desc.type_scale() as u32; + let dst_scale = match dst.arrow_type { + ArrowDataType::Decimal128(_percision, scale) => scale as u32, + _ => unreachable!() + }; + for _ in 0..num { let s = &mut dst_data[dst_offset..]; @@ -674,6 +681,15 @@ macro_rules! make_plain_decimal_impl { } } + if dst_scale > src_scale { + let exp = dst_scale - src_scale; + let mul = 10_i128.pow(exp); + let v = s.as_mut_ptr() as *mut i128; + unsafe { + v.write_unaligned(v.read_unaligned() * mul); + } + } + src_offset += byte_width; dst_offset += DECIMAL_BYTE_WIDTH; } diff --git a/core/src/parquet/util/jni.rs b/core/src/parquet/util/jni.rs index 62787213f..cde9fff0f 100644 --- a/core/src/parquet/util/jni.rs +++ b/core/src/parquet/util/jni.rs @@ -96,21 +96,24 @@ pub fn convert_encoding(ordinal: jint) -> Encoding { pub struct TypePromotionInfo { pub(crate) physical_type: PhysicalType, pub(crate) precision: i32, + pub(crate) scale: i32, } impl TypePromotionInfo { - pub fn new_from_jni(physical_type_id: jint, precision: jint) -> Self { + pub fn new_from_jni(physical_type_id: jint, precision: jint, scale: jint) -> Self { let physical_type = convert_physical_type(physical_type_id); Self { physical_type, precision, + scale, } } - pub fn new(physical_type: PhysicalType, precision: i32) -> Self { + pub fn new(physical_type: PhysicalType, precision: i32, scale: i32) -> Self { Self { physical_type, precision, + scale, } } } diff --git a/dev/diffs/4.0.0-preview1.diff b/dev/diffs/4.0.0-preview1.diff new file mode 100644 index 000000000..4031015df --- /dev/null +++ b/dev/diffs/4.0.0-preview1.diff @@ -0,0 +1,2816 @@ +diff --git a/pom.xml b/pom.xml +index a4b1b2c3c9f..a2315d2a95b 100644 +--- a/pom.xml ++++ b/pom.xml +@@ -147,6 +147,8 @@ + 0.10.0 + 2.5.2 + 2.0.8 ++ 4.0 ++ 0.1.0-SNAPSHOT + + + org.apache.datasketches +diff --git a/sql/core/pom.xml b/sql/core/pom.xml +index 19f6303be36..31e1d27700f 100644 +--- a/sql/core/pom.xml ++++ b/sql/core/pom.xml +@@ -77,6 +77,10 @@ + org.apache.spark + spark-tags_${scala.binary.version} + ++ ++ org.apache.comet ++ comet-spark-spark${spark.version.short}_${scala.binary.version} ++ + + + spark-4.0 - 2.13.13 + 2.13.14 2.13 4.0.0-preview1 4.0 1.13.1 + 4.9.5 + 2.0.13 spark-4.0 not-needed-yet @@ -632,7 +643,7 @@ under the License. org.scalameta semanticdb-scalac_${scala.version} - 4.8.8 + ${semanticdb.version} diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index e939b43a1..c19395684 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, shouldApplyRowToColumnar, withInfo, withInfos} +import org.apache.comet.CometSparkSessionExtensions.{createMessage, getCometShuffleNotEnabledReason, isANSIEnabled, isCometBroadCastForceEnabled, isCometEnabled, isCometExecEnabled, isCometJVMShuffleMode, isCometNativeShuffleMode, isCometOperatorEnabled, isCometScan, isCometScanEnabled, isCometShuffleEnabled, isSchemaSupported, isSpark34Plus, isSpark40Plus, shouldApplyRowToColumnar, withInfo, withInfos} import org.apache.comet.parquet.{CometParquetScan, SupportsComet} import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.QueryPlanSerde @@ -715,7 +715,9 @@ class CometSparkSessionExtensions // enabled. if (isANSIEnabled(conf)) { if (COMET_ANSI_MODE_ENABLED.get()) { - logWarning("Using Comet's experimental support for ANSI mode.") + if (!isSpark40Plus) { + logWarning("Using Comet's experimental support for ANSI mode.") + } } else { logInfo("Comet extension disabled for ANSI mode") return plan diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 67ecfe52d..13abaa0c4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -63,7 +63,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: DecimalType | _: DateType | _: BooleanType | _: NullType => true - case dt if dt.typeName == "timestamp_ntz" => true + case dt if isTimestampNTZType(dt) => true case dt => emitWarning(s"unsupported Spark data type: $dt") false @@ -87,7 +87,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case _: BinaryType => 8 case _: TimestampType => 9 case _: DecimalType => 10 - case dt if dt.typeName == "timestamp_ntz" => 11 + case dt if isTimestampNTZType(dt) => 11 case _: DateType => 12 case _: NullType => 13 case _: ArrayType => 14 @@ -1033,6 +1033,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) exprBuilder.setBytesVal(byteStr) case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case dt if isTimestampNTZType(dt) => + exprBuilder.setLongVal(value.asInstanceOf[Long]) case dt => logWarning(s"Unexpected date type '$dt' for literal value '$value'") } @@ -2241,8 +2243,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType => true - // `TimestampNTZType` is private in Spark 3.3. - case dt if dt.typeName == "timestamp_ntz" => true + case dt if isTimestampNTZType(dt) => true case _ => false } diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala index 150656c23..aa6db06d8 100644 --- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.DataType /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -32,5 +33,8 @@ trait CometExprShim { (unhex.child, Literal(false)) } + protected def isTimestampNTZType(dt: DataType): Boolean = + dt.typeName == "timestamp_ntz" // `TimestampNTZType` is private + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled) } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala index 5f4e3fba2..7709957b4 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{DataType, TimestampNTZType} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -32,6 +33,11 @@ trait CometExprShim { (unhex.child, Literal(unhex.failOnError)) } + protected def isTimestampNTZType(dt: DataType): Boolean = dt match { + case _: TimestampNTZType => true + case _ => false + } + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala index 65fb59a38..900b19895 100644 --- a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala +++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions import org.apache.spark.sql.execution.{FileSourceScanExec, PartitionedFileUtil} import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, HadoopFsRelation, PartitionDirectory, PartitionedFile} -import org.apache.spark.sql.types.{LongType, StructField, StructType} +import org.apache.spark.sql.types.StructType trait ShimCometScanExec { def wrapped: FileSourceScanExec @@ -87,24 +87,9 @@ trait ShimCometScanExec { .asInstanceOf[SparkException] } - // Copied from Spark 3.4 RowIndexUtil due to PARQUET-2161 (tracked in SPARK-39634) - // TODO: remove after PARQUET-2161 becomes available in Parquet - private def findRowIndexColumnIndexInSchema(sparkSchema: StructType): Int = { - sparkSchema.fields.zipWithIndex.find { case (field: StructField, _: Int) => - field.name == ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME - } match { - case Some((field: StructField, idx: Int)) => - if (field.dataType != LongType) { - throw new RuntimeException( - s"${ShimFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME} must be of LongType") - } - idx - case _ => -1 - } - } - protected def isNeededForSchema(sparkSchema: StructType): Boolean = { - findRowIndexColumnIndexInSchema(sparkSchema) >= 0 + // TODO: remove after PARQUET-2161 becomes available in Parquet (tracked in SPARK-39634) + ShimFileFormat.findRowIndexColumnIndexInSchema(sparkSchema) >= 0 } protected def getPartitionedFile(f: FileStatus, p: PartitionDirectory): PartitionedFile = diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 5f4e3fba2..7709957b4 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -20,6 +20,7 @@ package org.apache.comet.shims import org.apache.comet.expressions.CometEvalMode import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{DataType, TimestampNTZType} /** * `CometExprShim` acts as a shim for for parsing expressions from different Spark versions. @@ -32,6 +33,11 @@ trait CometExprShim { (unhex.child, Literal(unhex.failOnError)) } + protected def isTimestampNTZType(dt: DataType): Boolean = dt match { + case _: TimestampNTZType => true + case _ => false + } + protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalModeUtil.fromSparkEvalMode(c.evalMode) } diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala index bc18d8f10..9e6cbc0a6 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala @@ -1116,7 +1116,7 @@ abstract class ParquetReadSuite extends CometTestBase { } test("row group skipping doesn't overflow when reading into larger type") { - // Spark 4.0 no longer fails for widening types + // Spark 4.0 no longer fails for widening types SPARK-40876 // https://github.com/apache/spark/commit/3361f25dc0ff6e5233903c26ee105711b79ba967 assume(isSpark34Plus && !isSpark40Plus)