feat: Add array reading support to native_datafusion scan #1324

Draft · wants to merge 1 commit into base: main
native/core/src/execution/serde.rs (2 changes: 1 addition & 1 deletion)
@@ -111,7 +111,7 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
         {
             DatatypeStruct::List(info) => {
                 let field = Field::new(
-                    "item",
+                    "element",
Member Author commented:

This change was necessary to resolve an error when trying to cast `List(Field { name: "element", ..` to `List(Field { name: "item", ..` in `schema_adapter`.
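A Spark-side sketch of why the name matters (hypothetical output path, not part of this PR): the Parquet LIST spec, which Spark's writer follows, names the inner list field `element`, while Arrow's list fields default to `item`, so the Arrow schema built here has to say `element` to match what the scan reads back.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// Write an array<int> column and inspect the schema Spark reads back.
Seq(Seq(1, 2, 3)).toDF("a").write.mode("overwrite").parquet("/tmp/list_demo")
spark.read.parquet("/tmp/list_demo").printSchema()
// root
//  |-- a: array (nullable = true)
//  |    |-- element: integer (containsNull = true)
```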

                     to_arrow_datatype(info.element_type.as_ref().unwrap()),
                     info.contains_null,
                 );
ParquetGenerator.scala
@@ -42,9 +42,6 @@ object ParquetGenerator {
       DataTypes.createDecimalType(10, 2),
       DataTypes.createDecimalType(36, 18),
       DataTypes.DateType,
-      DataTypes.TimestampType,
-      // TimestampNTZType only in Spark 3.4+
-      // DataTypes.TimestampNTZType,
       DataTypes.StringType,
       DataTypes.BinaryType)

@@ -58,6 +55,12 @@ object ParquetGenerator {
     val dataTypes = ListBuffer[DataType]()
     dataTypes.appendAll(primitiveTypes)
 
+    if (options.generateTimestamps) {
+      dataTypes += DataTypes.TimestampType
+      // TimestampNTZType only in Spark 3.4+
+      // dataTypes += DataTypes.TimestampNTZType
+    }
+
     if (options.generateStruct) {
       dataTypes += StructType(
         primitiveTypes.zipWithIndex.map(x => StructField(s"c${x._2}", x._1, true)))
@@ -212,8 +215,9 @@ object ParquetGenerator {
   }
 
   case class DataGenOptions(
-      allowNull: Boolean,
-      generateNegativeZero: Boolean,
-      generateArray: Boolean,
-      generateStruct: Boolean,
-      generateMap: Boolean)
+      allowNull: Boolean = true,
+      generateNegativeZero: Boolean = true,
+      generateTimestamps: Boolean = true,
+      generateArray: Boolean = false,
+      generateStruct: Boolean = false,
+      generateMap: Boolean = false)
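With every flag defaulted, call sites only override what they need; the two test call sites below use exactly this pattern (a sketch assuming `random`, `spark`, and `filename` are in scope as in the suite):

```scala
// All defaults: timestamps on; arrays, structs, and maps off.
ParquetGenerator.makeParquetFile(random, spark, filename, 100, DataGenOptions())

// Arrays on, timestamps off (the native_datafusion case below).
ParquetGenerator.makeParquetFile(
  random,
  spark,
  filename,
  100,
  DataGenOptions(generateArray = true, generateTimestamps = false))
```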
CometNativeScanExec.scala
@@ -122,8 +122,9 @@ object CometNativeScanExec extends DataTypeSupport {
   }
 
   override def isAdditionallySupported(dt: DataType): Boolean = {
-    // TODO add array and map
+    // TODO add map support
     dt match {
+      case s: ArrayType => isTypeSupported(s.elementType)
       case s: StructType => s.fields.map(_.dataType).forall(isTypeSupported)
       case _ => false
     }
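A standalone mirror of the updated match, runnable outside Comet (a sketch: `leafSupported` stands in for `isTypeSupported`, which is not part of this diff):

```scala
import org.apache.spark.sql.types._

def additionallySupported(dt: DataType, leafSupported: DataType => Boolean): Boolean =
  dt match {
    case a: ArrayType  => leafSupported(a.elementType) // new: arrays follow their element type
    case s: StructType => s.fields.map(_.dataType).forall(leafSupported)
    case _             => false // maps still fall through (see the TODO)
  }

additionallySupported(ArrayType(IntegerType), _ == IntegerType)     // true
additionallySupported(MapType(StringType, IntegerType), _ => true)  // false
```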
CometArrayExpressionSuite.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.{array, col, expr, lit, udf}
-import org.apache.spark.sql.types.StructType
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark34Plus, isSpark35Plus}
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
@@ -55,17 +54,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
       val filename = path.toString
       val random = new Random(42)
       withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, DataGenOptions())
       }
       val table = spark.read.parquet(filename)
       table.createOrReplaceTempView("t1")
@@ -79,38 +68,31 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
     }
   }
 
test("array_remove - test all types (convert from Parquet)") {
test("array_remove - test arrays (native_datafusion reader)") {
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "test.parquet")
       val filename = path.toString
       val random = new Random(42)
       withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        val options = DataGenOptions(
-          allowNull = true,
-          generateNegativeZero = true,
-          generateArray = true,
-          generateStruct = true,
-          generateMap = false)
-        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          DataGenOptions(
+            generateArray = true,
+            // native_datafusion does not support timestamps correctly yet
+            generateTimestamps = false))
       }
-      withSQLConf(
-        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
-        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
-        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
+      withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION) {
         val table = spark.read.parquet(filename)
         table.createOrReplaceTempView("t1")
         // test with an array of each column
-        for (field <- table.schema.fields) {
-          val fieldName = field.name
+        for (fieldName <- table.schema.fieldNames) {
           sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
             .createOrReplaceTempView("t2")
           val df = sql("SELECT array_remove(a, b) FROM t2")
-          field.dataType match {
-            case _: StructType =>
-              // skip due to https://github.com/apache/datafusion-comet/issues/1314
-            case _ =>
-              checkSparkAnswer(df)
-          }
+          checkSparkAnswerAndOperator(df)
         }
       }
     }
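For context, the loop builds a two-element array from each column (`array(c, c)`) and removes that column's value from it. A minimal standalone illustration of `array_remove` semantics (a sketch, hypothetical local session):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()

// array_remove drops every element equal to the given value.
spark.sql("SELECT array_remove(array(1, 2, 1), 1) AS r").show()
// +---+
// |  r|
// +---+
// |[2]|
// +---+
```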