From 2cac768fabaa7cad40390b2205dd9c5000011e4c Mon Sep 17 00:00:00 2001
From: Shujing Yang
Date: Tue, 14 Nov 2023 13:55:25 +0900
Subject: [PATCH] [SPARK-45844][SQL] Implement case-insensitivity for XML

### What changes were proposed in this pull request?

This PR addresses the current lack of case-insensitive schema handling in the XML file format. Schema inference and file reads now follow the `SQLConf` case-sensitivity setting (`spark.sql.caseSensitive`).

We handle duplicate keys as follows (a usage sketch is appended after the diff):
1. When we encounter duplicates within a row (whether exact duplicates or names differing only in case), we convert them into an array and name it after the first occurrence.
2. When we encounter duplicates across rows, we likewise keep the name of the first occurrence.

Keys of map-type data are plain strings rather than field names, so they do not require case-sensitivity checks.

### Why are the changes needed?

To keep the XML data source consistent with the other file formats and to reduce maintenance effort.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43722 from shujingyang-db/case-sensitive.

Lead-authored-by: Shujing Yang
Co-authored-by: Shujing Yang <135740748+shujingyang-db@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon
---
 .../catalyst/expressions/xmlExpressions.scala |   3 +-
 .../sql/catalyst/xml/StaxXmlParser.scala      |  17 ++-
 .../sql/catalyst/xml/XmlInferSchema.scala     |  91 ++++++----
 .../datasources/xml/XmlDataSource.scala       |   7 +-
 .../attributes-case-sensitive.xml             |  12 ++
 .../execution/datasources/xml/XmlSuite.scala  | 139 ++++++++++++++++++
 6 files changed, 235 insertions(+), 34 deletions(-)
 create mode 100644 sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index c581643460f65..27c0a09fa1f06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -189,7 +189,8 @@ case class SchemaOfXml(
   private lazy val xmlFactory = xmlOptions.buildXmlFactory()

   @transient
-  private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions)
+  private lazy val xmlInferSchema =
+    new XmlInferSchema(xmlOptions, caseSensitive = SQLConf.get.caseSensitiveAnalysis)

   @transient
   private lazy val xml = child.eval().asInstanceOf[UTF8String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 77a0bd1dff179..754b54ce1575c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -35,10 +35,11 @@ import org.apache.spark.SparkUpgradeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -87,6 +88,14 @@ class StaxXmlParser( } } + private def getFieldNameToIndex(schema: StructType): Map[String, Int] = { + if (SQLConf.get.caseSensitiveAnalysis) { + schema.map(_.name).zipWithIndex.toMap + } else { + CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap) + } + } + def parseStream( inputStream: InputStream, schema: StructType): Iterator[InternalRow] = { @@ -274,7 +283,7 @@ class StaxXmlParser( val convertedValuesMap = collection.mutable.Map.empty[String, Any] val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options) valuesMap.foreach { case (f, v) => - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) nameToIndex.get(f).foreach { i => convertedValuesMap(f) = convertTo(v, schema(i).dataType) } @@ -313,7 +322,7 @@ class StaxXmlParser( // Here we merge both to a row. val valuesMap = fieldsMap ++ attributesMap valuesMap.foreach { case (f, v) => - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) nameToIndex.get(f).foreach { row(_) = v } } @@ -335,7 +344,7 @@ class StaxXmlParser( rootAttributes: Array[Attribute] = Array.empty, isRootAttributesOnly: Boolean = false): InternalRow = { val row = new Array[Any](schema.length) - val nameToIndex = schema.map(_.name).zipWithIndex.toMap + val nameToIndex = getFieldNameToIndex(schema) // If there are attributes, then we process them first. 
convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) => nameToIndex.get(f).foreach { row(_) = v } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index 777dd69fd7fa0..25f33e7f1bbdf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -36,7 +36,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, Timest import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ -private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging { +private[sql] class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) + extends Serializable + with Logging { private val decimalParser = ExprUtils.getDecimalParser(options.locale) @@ -115,8 +117,7 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } } - def infer(xml: String, - xsdSchema: Option[Schema] = None): Option[DataType] = { + def infer(xml: String, xsdSchema: Option[Schema] = None): Option[DataType] = { try { val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)) xsd.foreach { schema => @@ -199,14 +200,50 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with private def inferObject( parser: XMLEventReader, rootAttributes: Array[Attribute] = Array.empty): DataType = { - val builder = ArrayBuffer[StructField]() - val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]] + /** + * Retrieves the field name with respect to the case sensitivity setting. + * We pick the first name we encountered. + * + * If case sensitivity is enabled, the original field name is returned. + * If not, the field name is managed in a case-insensitive map. + * + * For instance, if we encounter the following field names: + * foo, Foo, FOO + * + * In case-sensitive mode: we will infer three fields: foo, Foo, FOO + * In case-insensitive mode, we will infer an array named by foo + * (as it's the first one we encounter) + */ + val caseSensitivityOrdering: Ordering[String] = (x: String, y: String) => + if (caseSensitive) { + x.compareTo(y) + } else { + x.compareToIgnoreCase(y) + } + + val nameToDataType = + collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering) + + def addOrUpdateType(fieldName: String, newType: DataType): Unit = { + val oldTypeOpt = nameToDataType.get(fieldName) + oldTypeOpt match { + // If the field name exists in the map, + // merge the type and infer the combined field as an array type if necessary + case Some(oldType) if !oldType.isInstanceOf[ArrayType] => + nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType))) + case Some(oldType) => + nameToDataType.update(fieldName, compatibleType(oldType, newType)) + case None => + nameToDataType.put(fieldName, newType) + } + } + // If there are attributes, then we should process them first. 
val rootValuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) rootValuesMap.foreach { case (f, v) => - nameToDataType += (f -> ArrayBuffer(inferFrom(v))) + addOrUpdateType(f, inferFrom(v)) } var shouldStop = false while (!shouldStop) { @@ -239,14 +276,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } // Add the field and datatypes so that we can check if this is ArrayType. val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) - val dataTypes = nameToDataType.getOrElse(field, ArrayBuffer.empty[DataType]) - dataTypes += inferredType - nameToDataType += (field -> dataTypes) + addOrUpdateType(field, inferredType) case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) - nameToDataType += options.valueTag -> ArrayBuffer(valueTagType) + addOrUpdateType(options.valueTag, valueTagType) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) @@ -258,25 +293,17 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with // if it only consists of attributes and valueTags. // If not, we will remove the valueTag field from the schema val attributesOnly = nameToDataType.forall { - case (fieldName, dataTypes) => - dataTypes.length == 1 && - (fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix)) + case (fieldName, _) => + fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix) } if (!attributesOnly) { nameToDataType -= options.valueTag } - // We need to manually merges the fields having the sames so that - // This can be inferred as ArrayType. - nameToDataType.foreach { - case (field, dataTypes) if dataTypes.length > 1 => - val elementType = dataTypes.reduceLeft(compatibleType) - builder += StructField(field, ArrayType(elementType), nullable = true) - case (field, dataTypes) => - builder += StructField(field, dataTypes.head, nullable = true) - } // Note: other code relies on this sorting for correctness, so don't remove it! - StructType(builder.sortBy(_.name).toArray) + StructType(nameToDataType.map{ + case (name, dataType) => StructField(name, dataType) + }.toList.sortBy(_.name)) } /** @@ -384,7 +411,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType): DataType = { + private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = { + + def normalize(name: String): String = { + if (caseSensitive) name else name.toLowerCase(Locale.ROOT) + } + // TODO: Optimise this logic. findTightestCommonTypeOfTwo(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. 
@@ -406,10 +438,15 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with } case (StructType(fields1), StructType(fields2)) => - val newFields = (fields1 ++ fields2).groupBy(_.name).map { - case (name, fieldTypes) => + val newFields = (fields1 ++ fields2) + // normalize field name and pair it with original field + .map(field => (normalize(field.name), field)) + .groupBy(_._1) // group by normalized field name + .map { case (_: String, fields: Array[(String, StructField)]) => + val fieldTypes = fields.map(_._2) val dataType = fieldTypes.map(_.dataType).reduce(compatibleType) - StructField(name, dataType, nullable = true) + // we pick up the first field name that we've encountered for the field + StructField(fields.head._2.name, dataType) } StructType(newFields.toArray.sortBy(_.name)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala index b09be84130abb..4b3c82bd83bc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala @@ -122,7 +122,8 @@ object TextInputXmlDataSource extends XmlDataSource { xml: Dataset[String], parsedOptions: XmlOptions): StructType = { SQLExecution.withSQLConfPropagated(xml.sparkSession) { - new XmlInferSchema(parsedOptions).infer(xml.rdd) + new XmlInferSchema(parsedOptions, xml.sparkSession.sessionState.conf.caseSensitiveAnalysis) + .infer(xml.rdd) } } @@ -179,7 +180,9 @@ object MultiLineXmlDataSource extends XmlDataSource { parsedOptions) } SQLExecution.withSQLConfPropagated(sparkSession) { - val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD) + val schema = + new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis) + .infer(tokenRDD) schema } } diff --git a/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml new file mode 100644 index 0000000000000..40a78fb279ba3 --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml @@ -0,0 +1,12 @@ + + + + 1 + 2 + 3 + 4 + + + 5 + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 21122676c46be..5a901dadff94d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.xml.XmlOptions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.xml.TestUtils._ import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1976,4 +1977,142 @@ class XmlSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(true, 10.1, -10, 10, "8E9D", "8E9F", Timestamp.valueOf("2015-01-01 00:00:00"))) } + + test("case sensitivity test - attributes-only object") { + val schemaCaseSensitive = new StructType() + .add("array", ArrayType( + new StructType() + .add("_Attr2", LongType) + 
.add("_VALUE", LongType) + .add("_aTTr2", LongType) + .add("_attr2", LongType))) + .add("struct", new StructType() + .add("_Attr1", LongType) + .add("_VALUE", LongType) + .add("_attr1", LongType)) + + val dfCaseSensitive = Seq( + Row( + Array( + Row(null, 2, null, 2), + Row(3, 3, null, null), + Row(null, 4, 4, null)), + Row(null, 1, 1) + ), + Row( + null, + Row(5, 5, null) + ) + ) + val schemaCaseInSensitive = new StructType() + .add("array", ArrayType(new StructType().add("_VALUE", LongType).add("_attr2", LongType))) + .add("struct", new StructType().add("_VALUE", LongType).add("_attr1", LongType)) + val dfCaseInsensitive = + Seq( + Row( + Array(Row(2, 2), Row(3, 3), Row(4, 4)), + Row(1, 1)), + Row(null, Row(5, 5))) + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val df = spark.read + .option("rowTag", "ROW") + .xml(getTestResourcePath(resDir + "attributes-case-sensitive.xml")) + assert(df.schema == (if (caseSensitive) schemaCaseSensitive else schemaCaseInSensitive)) + checkAnswer( + df, + if (caseSensitive) dfCaseSensitive else dfCaseInsensitive) + } + } + } + + testCaseSensitivity( + "basic", + writeData = Seq(Row(1L, null), Row(null, 2L)), + writeSchema = new StructType() + .add("A1", LongType) + .add("a1", LongType), + expectedSchema = new StructType() + .add("A1", LongType), + readDataCaseInsensitive = Seq(Row(1L), Row(2L))) + + testCaseSensitivity( + "nested struct", + writeData = Seq(Row(Row(1L), null), Row(null, Row(2L))), + writeSchema = new StructType() + .add("A1", new StructType().add("B1", LongType)) + .add("a1", new StructType().add("b1", LongType)), + expectedSchema = new StructType() + .add("A1", new StructType().add("B1", LongType)), + readDataCaseInsensitive = Seq(Row(Row(1L)), Row(Row(2L))) + ) + + testCaseSensitivity( + "convert fields into array", + writeData = Seq(Row(1L, 2L)), + writeSchema = new StructType() + .add("A1", LongType) + .add("a1", LongType), + expectedSchema = new StructType() + .add("A1", ArrayType(LongType)), + readDataCaseInsensitive = Seq(Row(Array(1L, 2L)))) + + testCaseSensitivity( + "basic array", + writeData = Seq(Row(Array(1L, 2L), Array(3L, 4L))), + writeSchema = new StructType() + .add("A1", ArrayType(LongType)) + .add("a1", ArrayType(LongType)), + expectedSchema = new StructType() + .add("A1", ArrayType(LongType)), + readDataCaseInsensitive = Seq(Row(Array(1L, 2L, 3L, 4L)))) + + testCaseSensitivity( + "nested array", + writeData = + Seq(Row(Array(Row(1L, 2L), Row(3L, 4L)), null), Row(null, Array(Row(5L, 6L), Row(7L, 8L)))), + writeSchema = new StructType() + .add("A1", ArrayType(new StructType().add("B1", LongType).add("d", LongType))) + .add("a1", ArrayType(new StructType().add("b1", LongType).add("c", LongType))), + expectedSchema = new StructType() + .add( + "A1", + ArrayType( + new StructType() + .add("B1", LongType) + .add("c", LongType) + .add("d", LongType))), + readDataCaseInsensitive = Seq( + Row(Array(Row(1L, null, 2L), Row(3L, null, 4L))), + Row(Array(Row(5L, 6L, null), Row(7L, 8L, null))))) + + def testCaseSensitivity( + name: String, + writeData: Seq[Row], + writeSchema: StructType, + expectedSchema: StructType, + readDataCaseInsensitive: Seq[Row]): Unit = { + test(s"case sensitivity test - $name") { + withTempDir { dir => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + spark + .createDataFrame(writeData.asJava, writeSchema) + .repartition(1) + .write + .option("rowTag", "ROW") + .format("xml") + .mode("overwrite") + .save(dir.getCanonicalPath) + } + 
+ Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + val df = spark.read.option("rowTag", "ROW").xml(dir.getCanonicalPath) + assert(df.schema == (if (caseSensitive) writeSchema else expectedSchema)) + checkAnswer(df, if (caseSensitive) writeData else readDataCaseInsensitive) + } + } + } + } + } }
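
A minimal usage sketch of the behaviour described in the PR description, for reviewers who want to try it locally. This is not part of the patch: the object name, local file path, and session setup are illustrative assumptions; only the `rowTag` option, the `spark.read...xml(...)` API, and the `spark.sql.caseSensitive` conf come from the change and its tests.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object (not part of this patch). It reads an XML file shaped like
// test-data/xml-resources/attributes-case-sensitive.xml under both settings of
// spark.sql.caseSensitive and prints the inferred schemas for comparison.
object XmlCaseSensitivityDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("xml-case-sensitivity-demo")
      .getOrCreate()

    // Assumed local copy of the new test resource; adjust the path as needed.
    val path = "attributes-case-sensitive.xml"

    // Default (case-insensitive): names differing only in case, such as _Attr2,
    // _aTTr2 and _attr2, are treated as one field named after the first occurrence.
    spark.conf.set("spark.sql.caseSensitive", "false")
    spark.read.option("rowTag", "ROW").xml(path).printSchema()

    // Case-sensitive: the differently-cased attributes remain separate fields.
    spark.conf.set("spark.sql.caseSensitive", "true")
    spark.read.option("rowTag", "ROW").xml(path).printSchema()

    spark.stop()
  }
}
```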