[SPARK-45844][SQL] Implement case-insensitivity for XML
### What changes were proposed in this pull request?

This PR addresses the lack of support for case-insensitive schema handling in the XML file format. Schema inference and file read operations now follow the `SQLConf` case-sensitivity setting (`spark.sql.caseSensitive`).

We handle duplicate keys as follows:
1. When we encounter duplicate field names (exact duplicates or names differing only in case) within a row, we convert them into an array and use the first name we encounter as the array's field name.
2. When we encounter duplicates across rows, we likewise keep the first name we encounter.

Keys of map-type data are plain strings rather than field names, so they do not require case-sensitivity checks.
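
As a rough illustration of the merge strategy, here is a minimal, dependency-free sketch in plain Scala, with strings standing in for Spark's `DataType` and a simplified `addOrUpdateType` (the real helper also checks whether the existing type is already an `ArrayType` before wrapping):

```scala
import scala.collection.mutable

object CaseInsensitiveMergeSketch extends App {
  val caseSensitive = false

  // Mirrors the ordering added in XmlInferSchema: names compare
  // case-insensitively when spark.sql.caseSensitive is false.
  val ordering: Ordering[String] = (x: String, y: String) =>
    if (caseSensitive) x.compareTo(y) else x.compareToIgnoreCase(y)

  val nameToType = mutable.TreeMap.empty[String, String](ordering)

  def addOrUpdateType(field: String, newType: String): Unit =
    nameToType.get(field) match {
      // A duplicate (up to case) was seen before: merge into an array type.
      case Some(oldType) => nameToType.update(field, s"array<compatible($oldType, $newType)>")
      case None => nameToType.put(field, newType)
    }

  Seq("foo" -> "string", "Foo" -> "long", "FOO" -> "double").foreach {
    case (f, t) => addOrUpdateType(f, t)
  }

  // The TreeMap keeps the first key it saw among keys that compare equal,
  // so this prints a single field named "foo" with an array type.
  println(nameToType)
}
```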

### Why are the changes needed?

To stay consistent with other file formats and to reduce maintenance effort.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests
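
A sketch of the shape such a test can take (hedged: the fixture path and expected values are assumptions; `withSQLConf` and `checkAnswer` are the standard Spark SQL test helpers):

```scala
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
  val df = spark.read
    .option("rowTag", "ROW")
    .format("xml")
    .load("test-data/xml-resources/attributes-case-sensitive.xml") // assumed path
  // attr1/Attr1 collapse into a single attribute field named after the
  // first spelling encountered, so both rows surface under "struct._attr1".
  checkAnswer(df.select("struct._attr1"), Seq(Row(1L), Row(5L)))
}
```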

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#43722 from shujingyang-db/case-sensitive.

Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
2 people authored and HyukjinKwon committed Nov 14, 2023
1 parent aa10ac7 commit 2cac768
Showing 6 changed files with 235 additions and 34 deletions.
@@ -189,7 +189,8 @@ case class SchemaOfXml(
   private lazy val xmlFactory = xmlOptions.buildXmlFactory()

   @transient
-  private lazy val xmlInferSchema = new XmlInferSchema(xmlOptions)
+  private lazy val xmlInferSchema =
+    new XmlInferSchema(xmlOptions, caseSensitive = SQLConf.get.caseSensitiveAnalysis)

   @transient
   private lazy val xml = child.eval().asInstanceOf[UTF8String]
@@ -35,10 +35,11 @@ import org.apache.spark.SparkUpgradeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -87,6 +88,14 @@ class StaxXmlParser(
     }
   }

+  private def getFieldNameToIndex(schema: StructType): Map[String, Int] = {
+    if (SQLConf.get.caseSensitiveAnalysis) {
+      schema.map(_.name).zipWithIndex.toMap
+    } else {
+      CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap)
+    }
+  }
+
   def parseStream(
       inputStream: InputStream,
       schema: StructType): Iterator[InternalRow] = {
@@ -274,7 +283,7 @@ class StaxXmlParser(
     val convertedValuesMap = collection.mutable.Map.empty[String, Any]
     val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
     valuesMap.foreach { case (f, v) =>
-      val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+      val nameToIndex = getFieldNameToIndex(schema)
       nameToIndex.get(f).foreach { i =>
         convertedValuesMap(f) = convertTo(v, schema(i).dataType)
       }
@@ -313,7 +322,7 @@ class StaxXmlParser(
     // Here we merge both to a row.
     val valuesMap = fieldsMap ++ attributesMap
     valuesMap.foreach { case (f, v) =>
-      val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+      val nameToIndex = getFieldNameToIndex(schema)
       nameToIndex.get(f).foreach { row(_) = v }
     }

@@ -335,7 +344,7 @@ class StaxXmlParser(
       rootAttributes: Array[Attribute] = Array.empty,
       isRootAttributesOnly: Boolean = false): InternalRow = {
     val row = new Array[Any](schema.length)
-    val nameToIndex = schema.map(_.name).zipWithIndex.toMap
+    val nameToIndex = getFieldNameToIndex(schema)
     // If there are attributes, then we process them first.
     convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) =>
       nameToIndex.get(f).foreach { row(_) = v }
@@ -36,7 +36,9 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, Timest
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.types._

-private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with Logging {
+private[sql] class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
+  extends Serializable
+  with Logging {

   private val decimalParser = ExprUtils.getDecimalParser(options.locale)

@@ -115,8 +117,7 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
     }
   }

-  def infer(xml: String,
-      xsdSchema: Option[Schema] = None): Option[DataType] = {
+  def infer(xml: String, xsdSchema: Option[Schema] = None): Option[DataType] = {
     try {
       val xsd = xsdSchema.orElse(Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema))
       xsd.foreach { schema =>
@@ -199,14 +200,50 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
   private def inferObject(
       parser: XMLEventReader,
       rootAttributes: Array[Attribute] = Array.empty): DataType = {
-    val builder = ArrayBuffer[StructField]()
-    val nameToDataType = collection.mutable.Map.empty[String, ArrayBuffer[DataType]]
+    /**
+     * Retrieves the field name with respect to the case sensitivity setting.
+     * We pick the first name we encountered.
+     *
+     * If case sensitivity is enabled, the original field name is returned.
+     * If not, the field name is managed in a case-insensitive map.
+     *
+     * For instance, if we encounter the following field names:
+     * foo, Foo, FOO
+     *
+     * In case-sensitive mode: we will infer three fields: foo, Foo, FOO
+     * In case-insensitive mode, we will infer an array named by foo
+     * (as it's the first one we encounter)
+     */
+    val caseSensitivityOrdering: Ordering[String] = (x: String, y: String) =>
+      if (caseSensitive) {
+        x.compareTo(y)
+      } else {
+        x.compareToIgnoreCase(y)
+      }
+
+    val nameToDataType =
+      collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering)
+
+    def addOrUpdateType(fieldName: String, newType: DataType): Unit = {
+      val oldTypeOpt = nameToDataType.get(fieldName)
+      oldTypeOpt match {
+        // If the field name exists in the map,
+        // merge the type and infer the combined field as an array type if necessary
+        case Some(oldType) if !oldType.isInstanceOf[ArrayType] =>
+          nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType)))
+        case Some(oldType) =>
+          nameToDataType.update(fieldName, compatibleType(oldType, newType))
+        case None =>
+          nameToDataType.put(fieldName, newType)
+      }
+    }
+
     // If there are attributes, then we should process them first.
     val rootValuesMap =
       StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
     rootValuesMap.foreach {
       case (f, v) =>
-        nameToDataType += (f -> ArrayBuffer(inferFrom(v)))
+        addOrUpdateType(f, inferFrom(v))
     }
     var shouldStop = false
     while (!shouldStop) {
@@ -239,14 +276,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
         }
         // Add the field and datatypes so that we can check if this is ArrayType.
         val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
-        val dataTypes = nameToDataType.getOrElse(field, ArrayBuffer.empty[DataType])
-        dataTypes += inferredType
-        nameToDataType += (field -> dataTypes)
+        addOrUpdateType(field, inferredType)

       case c: Characters if !c.isWhiteSpace =>
         // This can be an attribute-only object
         val valueTagType = inferFrom(c.getData)
-        nameToDataType += options.valueTag -> ArrayBuffer(valueTagType)
+        addOrUpdateType(options.valueTag, valueTagType)

       case _: EndElement =>
         shouldStop = StaxXmlParserUtils.checkEndElement(parser)
@@ -258,25 +293,17 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
     // if it only consists of attributes and valueTags.
     // If not, we will remove the valueTag field from the schema
     val attributesOnly = nameToDataType.forall {
-      case (fieldName, dataTypes) =>
-        dataTypes.length == 1 &&
-          (fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix))
+      case (fieldName, _) =>
+        fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix)
     }
     if (!attributesOnly) {
       nameToDataType -= options.valueTag
     }
-    // We need to manually merges the fields having the sames so that
-    // This can be inferred as ArrayType.
-    nameToDataType.foreach {
-      case (field, dataTypes) if dataTypes.length > 1 =>
-        val elementType = dataTypes.reduceLeft(compatibleType)
-        builder += StructField(field, ArrayType(elementType), nullable = true)
-      case (field, dataTypes) =>
-        builder += StructField(field, dataTypes.head, nullable = true)
-    }

-    // Note: other code relies on this sorting for correctness, so don't remove it!
-    StructType(builder.sortBy(_.name).toArray)
+    StructType(nameToDataType.map {
+      case (name, dataType) => StructField(name, dataType)
+    }.toList.sortBy(_.name))
   }

   /**
@@ -384,7 +411,12 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
   /**
    * Returns the most general data type for two given data types.
    */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
+  private[xml] def compatibleType(t1: DataType, t2: DataType): DataType = {
+
+    def normalize(name: String): String = {
+      if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
+    }
+
     // TODO: Optimise this logic.
     findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
@@ -406,10 +438,15 @@ private[sql] class XmlInferSchema(options: XmlOptions) extends Serializable with
         }

       case (StructType(fields1), StructType(fields2)) =>
-        val newFields = (fields1 ++ fields2).groupBy(_.name).map {
-          case (name, fieldTypes) =>
+        val newFields = (fields1 ++ fields2)
+          // normalize field name and pair it with original field
+          .map(field => (normalize(field.name), field))
+          .groupBy(_._1) // group by normalized field name
+          .map { case (_: String, fields: Array[(String, StructField)]) =>
+            val fieldTypes = fields.map(_._2)
             val dataType = fieldTypes.map(_.dataType).reduce(compatibleType)
-            StructField(name, dataType, nullable = true)
+            // we pick up the first field name that we've encountered for the field
+            StructField(fields.head._2.name, dataType)
           }
         StructType(newFields.toArray.sortBy(_.name))

@@ -122,7 +122,8 @@ object TextInputXmlDataSource extends XmlDataSource {
       xml: Dataset[String],
       parsedOptions: XmlOptions): StructType = {
     SQLExecution.withSQLConfPropagated(xml.sparkSession) {
-      new XmlInferSchema(parsedOptions).infer(xml.rdd)
+      new XmlInferSchema(parsedOptions, xml.sparkSession.sessionState.conf.caseSensitiveAnalysis)
+        .infer(xml.rdd)
     }
   }

@@ -179,7 +180,9 @@ object MultiLineXmlDataSource extends XmlDataSource {
         parsedOptions)
     }
     SQLExecution.withSQLConfPropagated(sparkSession) {
-      val schema = new XmlInferSchema(parsedOptions).infer(tokenRDD)
+      val schema =
+        new XmlInferSchema(parsedOptions, sparkSession.sessionState.conf.caseSensitiveAnalysis)
+          .infer(tokenRDD)
       schema
     }
   }
@@ -0,0 +1,12 @@
+<?xml version="1.0"?>
+<ROWSET>
+    <ROW>
+        <struct attr1="1">1</struct>
+        <array attr2="2">2</array>
+        <array Attr2="3">3</array>
+        <array aTTr2="4">4</array>
+    </ROW>
+    <ROW>
+        <struct Attr1="5">5</struct>
+    </ROW>
+</ROWSET>
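
For reference, reading this fixture with `spark.sql.caseSensitive=false` should collapse the case-variant names, keeping the first spelling seen. A hedged expectation based on the rules above, assuming the default `_` attribute prefix and `_VALUE` value tag:

```scala
// Hypothetical spark-shell session; the fixture path is an assumption.
val df = spark.read.option("rowTag", "ROW").format("xml")
  .load("test-data/xml-resources/attributes-case-sensitive.xml")
df.printSchema()
// root
//  |-- array: array (nullable = true)              <- attr2/Attr2/aTTr2 merged
//  |    |-- element: struct (containsNull = true)
//  |    |    |-- _VALUE: long (nullable = true)
//  |    |    |-- _attr2: long (nullable = true)
//  |-- struct: struct (nullable = true)            <- attr1/Attr1 merged
//  |    |-- _VALUE: long (nullable = true)
//  |    |-- _attr1: long (nullable = true)
```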