diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index c581643460f65..27c0a09fa1f06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -189,7 +189,8 @@ case class SchemaOfXml(
   private lazy val xmlFactory = xmlOptions.buildXmlFactory()
 
   @transient
-  private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions)
+  private lazy val xmlInferSchema =
+    new XmlInferSchema(xmlOptions, caseSensitive = SQLConf.get.caseSensitiveAnalysis)
 
   @transient
   private lazy val xml = child.eval().asInstanceOf[UTF8String]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 77a0bd1dff179..754b54ce1575c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -35,10 +35,11 @@ import org.apache.spark.SparkUpgradeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -87,6 +88,14 @@ class StaxXmlParser(
     }
   }
 
+  private def getFieldNameToIndex(schema: StructType): Map[String, Int] = {
+    if (SQLConf.get.caseSensitiveAnalysis) {
+      schema.map(_.name).zipWithIndex.toMap
+    } else {
+      CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap)
+    }
+  }
+
   def parseStream(
       inputStream: InputStream,
       schema: StructType): Iterator[InternalRow] = {
@@ -274,7 +283,7 @@ class StaxXmlParser(
     val convertedValuesMap = collection.mutable.Map.empty[String, Any]
     val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
     valuesMap.foreach { case (f, v) =>
-      val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+      val nameToIndex = getFieldNameToIndex(schema)
       nameToIndex.get(f).foreach { i =>
         convertedValuesMap(f) = convertTo(v, schema(i).dataType)
       }
     }
@@ -313,7 +322,7 @@ class StaxXmlParser(
     // Here we merge both to a row.
     val valuesMap = fieldsMap ++ attributesMap
     valuesMap.foreach { case (f, v) =>
-      val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+      val nameToIndex = getFieldNameToIndex(schema)
       nameToIndex.get(f).foreach { row(_) = v }
     }
@@ -335,7 +344,7 @@ class StaxXmlParser(
       rootAttributes: Array[Attribute] = Array.empty,
       isRootAttributesOnly: Boolean = false): InternalRow = {
     val row = new Array[Any](schema.length)
-    val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+    val nameToIndex = getFieldNameToIndex(schema)
     // If there are attributes, then we process them first.
     convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) =>
       nameToIndex.get(f).foreach { row(_) = v }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index 777dd69fd7fa0..25f33e7f1bbdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -36,7 +36,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, Timest
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.types._
 
-private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging {
+private[sql] class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
+  extends Serializable
+  with Logging {
 
   private val decimalParser = ExprUtils.getDecimalParser(options.locale)
@@ -115,8 +117,7 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
     }
   }
 
-  def infer(xml: String,
-      xsdSchema: Option[Schema] = None): Option[DataType] = {
+  def infer(xml: String, xsdSchema: Option[Schema] = None): Option[DataType] = {
     try {
       val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema))
       xsd.foreach { schema =>
@@ -199,14 +200,50 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
   private def inferObject(
       parser: XMLEventReader,
       rootAttributes: Array[Attribute] = Array.empty): DataType = {
-    val builder = ArrayBuffer[StructField]()
-    val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
+    /**
+     * Retrieves the field name with respect to the case sensitivity setting.
+     * We pick the first name we encounter.
+     *
+     * If case sensitivity is enabled, the original field name is returned.
+     * If not, the field name is managed in a case-insensitive map.
+     *
+     * For instance, if we encounter the following field names:
+     * foo, Foo, FOO
+     *
+     * In case-sensitive mode: we will infer three fields: foo, Foo, FOO
+     * In case-insensitive mode, we will infer an array named foo
+     * (as it's the first one we encounter)
+     */
+    val caseSensitivityOrdering: Ordering[String] = (x: String, y: String) =>
+      if (caseSensitive) {
+        x.compareTo(y)
+      } else {
+        x.compareToIgnoreCase(y)
+      }
+
+    val nameToDataType =
+      collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering)
+
+    def addOrUpdateType(fieldName: String, newType: DataType): Unit = {
+      val oldTypeOpt = nameToDataType.get(fieldName)
+      oldTypeOpt match {
+        // If the field name exists in the map,
+        // merge the type and infer the combined field as an array type if necessary
+        case Some(oldType) if !oldType.isInstanceOf[ArrayType] =>
+          nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType)))
+        case Some(oldType) =>
+          nameToDataType.update(fieldName, compatibleType(oldType, newType))
+        case None =>
+          nameToDataType.put(fieldName, newType)
+      }
+    }
+
     // If there are attributes, then we should process them first.
     val rootValuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
     rootValuesMap.foreach { case (f, v) =>
-      nameToDataType += (f -> ArrayBuffer(inferFrom(v)))
+      addOrUpdateType(f, inferFrom(v))
     }
 
     var shouldStop = false
     while (!shouldStop) {
@@ -239,14 +276,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
           }
           // Add the field and datatypes so that we can check if this is ArrayType.
           val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
-          val dataTypes = nameToDataType.getOrElse(field, ArrayBuffer.empty[DataType])
-          dataTypes += inferredType
-          nameToDataType += (field -> dataTypes)
+          addOrUpdateType(field, inferredType)
 
         case c: Characters if !c.isWhiteSpace =>
           // This can be an attribute-only object
           val valueTagType = inferFrom(c.getData)
-          nameToDataType += options.valueTag -> ArrayBuffer(valueTagType)
+          addOrUpdateType(options.valueTag, valueTagType)
 
         case _: EndElement =>
           shouldStop = StaxXmlParserUtils.checkEndElement(parser)
@@ -258,25 +293,17 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
     // if it only consists of attributes and valueTags.
     // If not, we will remove the valueTag field from the schema
     val attributesOnly = nameToDataType.forall {
-      case (fieldName, dataTypes) =>
-        dataTypes.length == 1 &&
-          (fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix))
+      case (fieldName, _) =>
+        fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix)
     }
     if (!attributesOnly) {
       nameToDataType -= options.valueTag
     }
-    // We need to manually merges the fields having the sames so that
-    // This can be inferred as ArrayType.
-    nameToDataType.foreach {
-      case (field, dataTypes) if dataTypes.length > 1 =>
-        val elementType = dataTypes.reduceLeft(compatibleType)
-        builder += StructField(field, ArrayType(elementType), nullable = true)
-      case (field, dataTypes) =>
-        builder += StructField(field, dataTypes.head, nullable = true)
-    }
 
     // Note: other code relies on this sorting for correctness, so don't remove it!
-    StructType(builder.sortBy(_.name).toArray)
+    StructType(nameToDataType.map{
+      case (name, dataType) => StructField(name, dataType)
+    }.toList.sortBy(_.name))
   }
 
   /**
@@ -384,7 +411,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
   /**
    * Returns the most general data type for two given data types.
    */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
+  private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = {
+
+    def normalize(name: String): String = {
+      if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+    }
+
     // TODO: Optimise this logic.
     findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
@@ -406,10 +438,15 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
         }
 
         case (StructType(fields1), StructType(fields2)) =>
-          val newFields = (fields1 ++ fields2).groupBy(_.name).map {
-            case (name, fieldTypes) =>
+          val newFields = (fields1 ++ fields2)
+            // normalize field name and pair it with original field
+            .map(field => (normalize(field.name), field))
+            .groupBy(_._1) // group by normalized field name
+            .map { case (_: String, fields: Array[(String, StructField)]) =>
+              val fieldTypes = fields.map(_._2)
               val dataType = fieldTypes.map(_.dataType).reduce(compatibleType)
-              StructField(name, dataType, nullable = true)
+              // we pick up the first field name that we've encountered for the field
+              StructField(fields.head._2.name, dataType)
             }
           StructType(newFields.toArray.sortBy(_.name))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
index b09be84130abb..4b3c82bd83bc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlDataSource.scala
@@ -122,7 +122,8 @@ object TextInputXmlDataSource extends XmlDataSource {
       xml: Dataset[String],
       parsedOptions: XmlOptions): StructType = {
     SQLExecution.withSQLConfPropagated(xml.sparkSession) {
-      new XmlInferSchema(parsedOptions).infer(xml.rdd)
+      new XmlInferSchema(parsedOptions, xml.sparkSession.sessionState.conf.caseSensitiveAnalysis)
+        .infer(xml.rdd)
     }
   }
 
@@ -179,7 +180,9 @@ object MultiLineXmlDataSource extends XmlDataSource {
         parsedOptions)
     }
     SQLExecution.withSQLConfPropagated(sparkSession) {
-      val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD)
+      val schema =
+        new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis)
+          .infer(tokenRDD)
       schema
     }
   }
diff --git a/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml
new file mode 100644
index 0000000000000..40a78fb279ba3
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/xml-resources/attributes-case-sensitive.xml
@@ -0,0 +1,12 @@
+<?xml version="1.0"?>
+<ROWSET>
+    <ROW>
+        <struct attr1="1">1</struct>
+        <array attr2="2">2</array>
+        <array Attr2="3">3</array>
+        <array aTTr2="4">4</array>
+    </ROW>
+    <ROW>
+        <struct Attr1="5">5</struct>
+    </ROW>
+</ROWSET>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 21122676c46be..5a901dadff94d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.xml.XmlOptions._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.xml.TestUtils._
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1976,4 +1977,142 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df, Row(true, 10.1, -10, 10, "8E9D", "8E9F",
       Timestamp.valueOf("2015-01-01 00:00:00")))
   }
+
+  test("case sensitivity test - attributes-only object") {
+    val schemaCaseSensitive = new StructType()
+      .add("array", ArrayType(
+        new StructType()
+          .add("_Attr2", LongType)
+          .add("_VALUE", LongType)
+          .add("_aTTr2", LongType)
+          .add("_attr2", LongType)))
+      .add("struct", new StructType()
+        .add("_Attr1", LongType)
+        .add("_VALUE", LongType)
+        .add("_attr1", LongType))
+
+    val dfCaseSensitive = Seq(
+      Row(
+        Array(
+          Row(null, 2, null, 2),
+          Row(3, 3, null, null),
+          Row(null, 4, 4, null)),
+        Row(null, 1, 1)
+      ),
+      Row(
+        null,
+        Row(5, 5, null)
+      )
+    )
+    val schemaCaseInSensitive = new StructType()
+      .add("array", ArrayType(new StructType().add("_VALUE", LongType).add("_attr2", LongType)))
+      .add("struct", new StructType().add("_VALUE", LongType).add("_attr1", LongType))
+    val dfCaseInsensitive =
+      Seq(
+        Row(
+          Array(Row(2, 2), Row(3, 3), Row(4, 4)),
+          Row(1, 1)),
+        Row(null, Row(5, 5)))
+    Seq(true, false).foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+        val df = spark.read
+          .option("rowTag", "ROW")
+          .xml(getTestResourcePath(resDir + "attributes-case-sensitive.xml"))
+        assert(df.schema == (if (caseSensitive) schemaCaseSensitive else schemaCaseInSensitive))
+        checkAnswer(
+          df,
+          if (caseSensitive) dfCaseSensitive else dfCaseInsensitive)
+      }
+    }
+  }
+
+  testCaseSensitivity(
+    "basic",
+    writeData = Seq(Row(1L, null), Row(null, 2L)),
+    writeSchema = new StructType()
+      .add("A1", LongType)
+      .add("a1", LongType),
+    expectedSchema = new StructType()
+      .add("A1", LongType),
+    readDataCaseInsensitive = Seq(Row(1L), Row(2L)))
+
+  testCaseSensitivity(
+    "nested struct",
+    writeData = Seq(Row(Row(1L), null), Row(null, Row(2L))),
+    writeSchema = new StructType()
+      .add("A1", new StructType().add("B1", LongType))
+      .add("a1", new StructType().add("b1", LongType)),
+    expectedSchema = new StructType()
+      .add("A1", new StructType().add("B1", LongType)),
+    readDataCaseInsensitive = Seq(Row(Row(1L)), Row(Row(2L)))
+  )
+
+  testCaseSensitivity(
+    "convert fields into array",
+    writeData = Seq(Row(1L, 2L)),
+    writeSchema = new StructType()
+      .add("A1", LongType)
+      .add("a1", LongType),
+    expectedSchema = new StructType()
+      .add("A1", ArrayType(LongType)),
+    readDataCaseInsensitive = Seq(Row(Array(1L, 2L))))
+
+  testCaseSensitivity(
+    "basic array",
+    writeData = Seq(Row(Array(1L, 2L), Array(3L, 4L))),
+    writeSchema = new StructType()
+      .add("A1", ArrayType(LongType))
+      .add("a1", ArrayType(LongType)),
+    expectedSchema = new StructType()
+      .add("A1", ArrayType(LongType)),
+    readDataCaseInsensitive = Seq(Row(Array(1L, 2L, 3L, 4L))))
+
+  testCaseSensitivity(
+    "nested array",
+    writeData =
+      Seq(Row(Array(Row(1L, 2L), Row(3L, 4L)), null), Row(null, Array(Row(5L, 6L), Row(7L, 8L)))),
+    writeSchema = new StructType()
+      .add("A1", ArrayType(new StructType().add("B1", LongType).add("d", LongType)))
+      .add("a1", ArrayType(new StructType().add("b1", LongType).add("c", LongType))),
+    expectedSchema = new StructType()
+      .add(
+        "A1",
+        ArrayType(
+          new StructType()
+            .add("B1", LongType)
+            .add("c", LongType)
+            .add("d", LongType))),
+    readDataCaseInsensitive = Seq(
+      Row(Array(Row(1L, null, 2L), Row(3L, null, 4L))),
+      Row(Array(Row(5L, 6L, null), Row(7L, 8L, null)))))
+
+  def testCaseSensitivity(
+      name: String,
+      writeData: Seq[Row],
+      writeSchema: StructType,
+      expectedSchema: StructType,
+      readDataCaseInsensitive: Seq[Row]): Unit = {
+    test(s"case sensitivity test - $name") {
+      withTempDir { dir =>
+        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+          spark
+            .createDataFrame(writeData.asJava, writeSchema)
+            .repartition(1)
+            .write
+            .option("rowTag", "ROW")
+            .format("xml")
+            .mode("overwrite")
+            .save(dir.getCanonicalPath)
+        }
+
+        Seq(true, false).foreach { caseSensitive =>
+          withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) {
+            val df = spark.read.option("rowTag", "ROW").xml(dir.getCanonicalPath)
+            assert(df.schema == (if (caseSensitive) writeSchema else expectedSchema))
+            checkAnswer(df, if (caseSensitive) writeData else readDataCaseInsensitive)
+          }
+        }
+      }
+    }
+  }
 }
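Reviewer note: the heart of the change is that inferObject now collects field types in a mutable.TreeMap whose ordering is case-insensitive when caseSensitive is false, so differently-cased names collapse into one entry (keeping the first spelling seen) and repeated names are promoted to an array. Below is a minimal, self-contained sketch of that grouping behaviour using plain Scala collections. It is not part of the patch: the object name, groupFields, the string-based "types", and mergeType are illustrative stand-ins for Spark's DataType and compatibleType, and the array-merging branch is simplified.

import scala.collection.mutable

object CaseInsensitiveGroupingSketch {

  // Toy stand-in for XmlInferSchema.compatibleType: merge two "types".
  private def mergeType(t1: String, t2: String): String =
    if (t1 == t2) t1 else "string"

  // Mirrors the caseSensitivityOrdering + TreeMap + addOrUpdateType pattern
  // from the patch, with String used in place of Spark's DataType.
  def groupFields(caseSensitive: Boolean, fields: Seq[(String, String)]): Map[String, String] = {
    val ordering: Ordering[String] = (x: String, y: String) =>
      if (caseSensitive) x.compareTo(y) else x.compareToIgnoreCase(y)

    val nameToType = mutable.TreeMap.empty[String, String](ordering)

    fields.foreach { case (name, tpe) =>
      nameToType.get(name) match {
        // Name already seen with an array type: merge into the element type.
        case Some(old) if old.startsWith("array<") =>
          val elem = old.stripPrefix("array<").stripSuffix(">")
          nameToType.update(name, s"array<${mergeType(elem, tpe)}>")
        // Name already seen once: the combined field becomes an array.
        case Some(old) =>
          nameToType.update(name, s"array<${mergeType(old, tpe)}>")
        // First occurrence: the TreeMap keeps this spelling as the key.
        case None =>
          nameToType.put(name, tpe)
      }
    }
    nameToType.toMap
  }

  def main(args: Array[String]): Unit = {
    val input = Seq("foo" -> "long", "Foo" -> "long", "FOO" -> "long")
    // Case-sensitive: three distinct fields survive.
    println(groupFields(caseSensitive = true, input))   // Map(FOO -> long, Foo -> long, foo -> long)
    // Case-insensitive: one field, named after the first spelling seen, inferred as an array.
    println(groupFields(caseSensitive = false, input))  // Map(foo -> array<long>)
  }
}

This mirrors the "convert fields into array" test above: with case-insensitive analysis, A1 and a1 collapse into a single A1 column inferred as ArrayType(LongType).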