diff --git a/benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerBenchmark.scala b/benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerBenchmark.scala
new file mode 100644
index 0000000..d3fbbc2
--- /dev/null
+++ b/benchmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerBenchmark.scala
@@ -0,0 +1,104 @@
+package com.github.mrpowers.spark.fast.tests
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.openjdk.jmh.annotations._
+import org.openjdk.jmh.infra.Blackhole
+
+import java.util.concurrent.TimeUnit
+import scala.util.Try
+
+// JMH requires benchmark classes to be public; a top-level `private` (package-private) class
+// is rejected by the JMH bytecode generator, so this class must not be private.
+class DatasetComparerBenchmark extends DatasetComparer {
+  def getSparkSession: SparkSession = {
+    val session = SparkSession
+      .builder()
+      .master("local")
+      .appName("spark session")
+      .getOrCreate()
+    session.sparkContext.setLogLevel("ERROR")
+    session
+  }
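+
+  // build.sbt hints that the sbt-jmh plugin is in use (the jmh-generator-annprocess comment
+  // references it), so these benchmarks would presumably be run with something like:
+  //   sbt "benchmarks/Jmh/run .*DatasetComparerBenchmark.*"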
+
+  @Benchmark
+  @BenchmarkMode(Array(Mode.SingleShotTime))
+  @Fork(value = 2)
+  @Warmup(iterations = 10)
+  @Measurement(iterations = 10)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  def assertLargeDatasetEqualityV2(blackHole: Blackhole): Boolean = {
+    val spark = getSparkSession
+    val ds1 = spark.range(0, 1000000, 1, 8)
+    val ds3 = ds1
+
+    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3))
+
+    blackHole.consume(result)
+    result.isSuccess
+  }
+
+  @Benchmark
+  @BenchmarkMode(Array(Mode.SingleShotTime))
+  @Fork(value = 2)
+  @Warmup(iterations = 10)
+  @Measurement(iterations = 10)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  def assertLargeDatasetEqualityV2WithSinglePrimaryKey(blackHole: Blackhole): Boolean = {
+    val spark = getSparkSession
+    val ds1 = spark.range(0, 1000000, 1, 8)
+    val ds3 = ds1
+
+    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id")))
+
+    blackHole.consume(result)
+    result.isSuccess
+  }
+
+  @Benchmark
+  @BenchmarkMode(Array(Mode.SingleShotTime))
+  @Fork(value = 2)
+  @Warmup(iterations = 10)
+  @Measurement(iterations = 10)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  def assertLargeDatasetEquality(blackHole: Blackhole): Boolean = {
+    val spark = getSparkSession
+    val ds1 = spark.range(0, 1000000, 1, 8)
+    val ds3 = ds1
+
+    val result = Try(assertLargeDatasetEquality(ds1, ds3))
+
+    blackHole.consume(result)
+    result.isSuccess
+  }
+
+  @Benchmark
+  @BenchmarkMode(Array(Mode.SingleShotTime))
+  @Fork(value = 2)
+  @Warmup(iterations = 10)
+  @Measurement(iterations = 10)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  def assertLargeDatasetEqualityV2WithCompositePrimaryKey2(blackHole: Blackhole): Boolean = {
+    val spark = getSparkSession
+    val ds1 = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1)
+    val ds3 = ds1
+    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2")))
+
+    blackHole.consume(result)
+    result.isSuccess
+  }
+
+  @Benchmark
+  @BenchmarkMode(Array(Mode.SingleShotTime))
+  @Fork(value = 2)
+  @Warmup(iterations = 10)
+  @Measurement(iterations = 10)
+  @OutputTimeUnit(TimeUnit.NANOSECONDS)
+  def assertLargeDatasetEqualityV2WithCompositePrimaryKey3(blackHole: Blackhole): Boolean = {
+    val spark = getSparkSession
+    val ds1 = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1).withColumn("id3", col("id2") + 1)
+    val ds3 = ds1
+    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2", "id3")))
+
+    blackHole.consume(result)
+    result.isSuccess
+  }
+}
diff --git a/build.sbt b/build.sbt
index f9f8fb1..8bdcf85 100644
--- a/build.sbt
+++ b/build.sbt
@@ -38,14 +38,16 @@ lazy val commonSettings = Seq(
     else Seq.empty)
   },
   libraryDependencies ++= Seq(
-    "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
-    "org.scalatest" %% "scalatest" % "3.2.18" % "test"
+    "org.scalatest" %% "scalatest" % "3.2.18" % "test"
   )
 )
 
 lazy val core = (project in file("core"))
   .settings(
     commonSettings,
+    libraryDependencies ++= Seq(
+      "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
+    ),
     name := "core",
     Compile / packageSrc / publishArtifact := true,
     Compile / packageDoc / publishArtifact := true
@@ -56,7 +58,8 @@ lazy val benchmarks = (project in file("benchmarks"))
   .settings(commonSettings)
   .settings(
     libraryDependencies ++= Seq(
-      "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version!
+      "org.apache.spark" %% "spark-sql" % sparkVersion % "compile",
+      "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version!
     ),
     name := "benchmarks",
     publish / skip := true
@@ -121,4 +124,4 @@ updateOptions := updateOptions.value.withLatestSnapshots(false)
 
 import xerial.sbt.Sonatype.sonatypeCentralHost
 
-ThisBuild / sonatypeCredentialHost := sonatypeCentralHost
\ No newline at end of file
+ThisBuild / sonatypeCredentialHost := sonatypeCentralHost
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
index 0ff191d..f06d830 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
@@ -1,10 +1,11 @@
 package com.github.mrpowers.spark.fast.tests
 
+import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
 import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
 import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 
 import scala.reflect.ClassTag
 
@@ -72,7 +73,7 @@ Expected DataFrame Row Count: '$expectedCount'
     val e = expectedDS.collect().toSeq
     if (!a.approximateSameElements(e, equals)) {
       val arr = ("Actual Content", "Expected Content")
-      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
+      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, Left(a -> e), truncate)
       throw DatasetContentMismatch(msg)
     }
   }
@@ -152,6 +153,89 @@ Expected DataFrame Row Count: '$expectedCount'
     }
   }
 
+  /**
+   * Raises an error unless `actualDS` and `expectedDS` are equal. Providing `primaryKeys` is recommended to ensure an accurate and
+   * efficient comparison of rows. If no primary keys are provided, rows are compared by their row number; this requires both datasets
+   * to be partitioned in the same way and becomes unreliable when shuffling is involved.
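+   *
+   * A minimal usage sketch (the `actualDF`/`expectedDF` values and the `id` key column are hypothetical):
+   * {{{
+   * assertLargeDatasetEqualityV2(actualDF, expectedDF, primaryKeys = Seq("id"))
+   * }}}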
+   * @param primaryKeys
+   *   unique identifier for each row, used to match rows across the two datasets
+   * @param checkKeyUniqueness
+   *   if true, verifies that the primary key is actually unique in each dataset
+   */
+  def assertLargeDatasetEqualityV2[T: ClassTag](
+      actualDS: Dataset[T],
+      expectedDS: Dataset[T],
+      equals: Either[(T, T) => Boolean, Option[Column]] = Right(None),
+      ignoreNullable: Boolean = false,
+      ignoreColumnNames: Boolean = false,
+      ignoreColumnOrder: Boolean = false,
+      ignoreMetadata: Boolean = true,
+      checkKeyUniqueness: Boolean = false,
+      primaryKeys: Seq[String] = Seq.empty,
+      truncate: Int = 500
+  ): Unit = {
+    // first check that the schemas are equal
+    SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
+    val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
+    assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, checkKeyUniqueness, truncate)
+  }
+
+  def assertLargeDatasetContentEqualityV2[T: ClassTag](
+      ds1: Dataset[T],
+      ds2: Dataset[T],
+      equals: Either[(T, T) => Boolean, Option[Column]],
+      primaryKeys: Seq[String],
+      checkKeyUniqueness: Boolean,
+      truncate: Int
+  ): Unit = {
+    try {
+      ds1.cache()
+      ds2.cache()
+
+      val actualCount = ds1.count
+      val expectedCount = ds2.count
+
+      if (actualCount != expectedCount) {
+        throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
+      }
+
+      if (primaryKeys.nonEmpty && checkKeyUniqueness) {
+        assert(ds1.isKeyUnique(primaryKeys), "Primary key is not unique in actual dataset")
+        assert(ds2.isKeyUnique(primaryKeys), "Primary key is not unique in expected dataset")
+      }
+
+      val joinedDf = ds1
+        .joinPair(ds2, primaryKeys)
+
+      val unequalDS = equals match {
+        case Left(customEquals) =>
+          joinedDf.filter((p: (T, T)) =>
+            // joinWith implicitly returns null on either side for rows missing from the full outer join, even for primitive types
+            p match {
+              case (null, null) => false
+              case (null, _)    => true
+              case (_, null)    => true
+              case (l, r)       => !customEquals(l, r)
+            }
+          )
+
+        case Right(equalExprOption) =>
+          joinedDf.filter(equalExprOption.getOrElse(col("_1") =!= col("_2")))
+      }
+
+      if (!unequalDS.isEmpty) {
+        val joined = Right(unequalDS.take(truncate).toSeq)
+        val arr = ("Actual Content", "Expected Content")
+        val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, joined, truncate)
+        throw DatasetContentMismatch(msg)
+      }
+
+    } finally {
+      ds1.unpersist()
+      ds2.unpersist()
+    }
+  }
+
   def assertApproximateDataFrameEquality(
       actualDF: DataFrame,
       expectedDF: DataFrame,
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala
new file mode 100644
index 0000000..6e5a141
--- /dev/null
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala
@@ -0,0 +1,58 @@
+package com.github.mrpowers.spark.fast.tests
+
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, TypedColumn}
+
+import scala.reflect.ClassTag
+
+private object DatasetUtils {
+  implicit class DatasetOps[T: ClassTag](ds: Dataset[T]) {
+    // Assigns a global row number to each row; the unpartitioned window moves all rows through a single partition
+    def zipWithIndex(indexName: String): DataFrame = ds
+      .orderBy()
+      .withColumn(indexName, row_number().over(Window.orderBy(monotonically_increasing_id())))
+      .select(ds.columns.map(col) :+ col(indexName): _*)
+
+    /**
+     * Check if the primary key is actually unique
+     */
+    def isKeyUnique(primaryKey: Seq[String]): Boolean =
+      ds.select(primaryKey.map(col): _*).distinct.count == ds.count
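+
+    /**
+     * Pairs rows from the two datasets: by equality on `primaryKeys` when keys are provided, otherwise by the synthetic
+     * row index from `zipWithIndex`. The join is a full outer join, so rows missing on either side surface as nulls.
+     */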
+    def joinPair[P: ClassTag](
+        other: Dataset[P],
+        primaryKeys: Seq[String]
+    ): Dataset[(T, P)] = {
+      if (primaryKeys.nonEmpty) {
+        ds
+          .as("l")
+          .joinWith(other.as("r"), primaryKeys.map(k => col(s"l.$k") === col(s"r.$k")).reduce(_ && _), "full_outer")
+      } else {
+        val indexName = s"index_${java.util.UUID.randomUUID}"
+        val joined = ds
+          .zipWithIndex(indexName)
+          .alias("l")
+          .joinWith(other.zipWithIndex(indexName).alias("r"), col(s"l.$indexName") === col(s"r.$indexName"), "full_outer")
+
+        joined
+          .select(
+            encoderToCol("_1", ds.schema, ds.encoder, Seq(indexName)),
+            encoderToCol("_2", other.schema, other.encoder, Seq(indexName))
+          )
+      }
+    }
+  }
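+
+  /**
+   * Rebuilds a typed column for one side of the joined pair. Columns are re-selected by the schema's field names
+   * (the names carried by the encoder are not reliable). A side whose key columns are all null, i.e. an unmatched
+   * outer-join row, is collapsed back to a null value instead of a struct of nulls.
+   */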
+  private def encoderToCol[P: ClassTag](colName: String, schema: StructType, encoder: Encoder[P], key: Seq[String]): TypedColumn[Any, P] = {
+    val columns = schema.names.map(n => col(s"$colName.$n")) // field names from the encoder are not reliable
+    val isRowType = implicitly[ClassTag[P]].runtimeClass == classOf[Row]
+    val unTypedColumn =
+      if (columns.length == 1 && !isRowType)
+        columns.head
+      else
+        when(key.map(k => col(s"$colName.$k").isNull).reduce(_ && _), lit(null)).otherwise(struct(columns: _*))
+
+    unTypedColumn.as(colName).as[P](encoder)
+  }
+}
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
index b4b464c..dc1662b 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala
@@ -5,6 +5,7 @@ import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
 import org.apache.commons.lang3.StringUtils
 import org.apache.spark.sql.Row
 
+import scala.annotation.tailrec
 import scala.reflect.ClassTag
 
 object ProductUtil {
@@ -20,8 +21,7 @@
   }
   private[mrpowers] def showProductDiff[T: ClassTag](
       header: (String, String),
-      actual: Seq[T],
-      expected: Seq[T],
+      data: Either[(Seq[T], Seq[T]), Seq[(T, T)]],
       truncate: Int = 20,
       minColWidth: Int = 3
   ): String = {
@@ -33,7 +33,10 @@
 
     val sb = new StringBuilder
 
-    val fullJoin = actual.zipAll(expected, null, null)
+    val fullJoin = data match {
+      case Left((actual, expected)) => actual.zipAll(expected, null, null)
+      case Right(joined)            => joined
+    }
 
     val diff = fullJoin.map { case (actualRow, expectedRow) =>
       if (actualRow == expectedRow) {
diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
index 32ceaf7..92b66dd 100644
--- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
+++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala
@@ -14,8 +14,7 @@ object SchemaComparer {
   private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
     showProductDiff(
       ("Actual Schema", "Expected Schema"),
-      actualSchema.fields,
-      expectedSchema.fields,
+      Left(actualSchema.fields.toSeq -> expectedSchema.fields.toSeq),
       truncate = 200
     )
   }
diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala
new file mode 100644
index 0000000..9c0c01a
--- /dev/null
+++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerV2Test.scala
@@ -0,0 +1,482 @@
+package com.github.mrpowers.spark.fast.tests
+
+import org.apache.spark.sql.types._
+import SparkSessionExt._
+import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
+import com.github.mrpowers.spark.fast.tests.TestUtilsExt.ExceptionOps
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.functions.col
+import org.scalatest.freespec.AnyFreeSpec
+
+class DatasetComparerV2Test extends AnyFreeSpec with DatasetComparer {
+  lazy val spark: SparkSession = {
+    val session = SparkSession
+      .builder()
+      .master("local")
+      .appName("spark session")
+      .config("spark.sql.shuffle.partitions", "3")
+      .getOrCreate()
+    session.sparkContext.setLogLevel("ERROR")
+    session
+  }
+
+  "checkDatasetEquality" - {
+    import spark.implicits._
+
+    "can compare DataFrames" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1, "text"),
+          (5, "text")
+        ),
+        List(("number", IntegerType, true), ("text", StringType, true))
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          (1, "text"),
+          (5, "text")
+        ),
+        List(("number", IntegerType, true), ("text", StringType, true))
+      )
+
+      assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+    }
+
+    "can compare Dataset[Array[_]]" in {
+      val sourceDS = Seq(
+        Array("apple", "banana", "cherry"),
+        Array("dog", "cat"),
+        Array("red", "green", "blue")
+      ).toDS
+
+      val expectedDS = Seq(
+        Array("apple", "banana", "cherry"),
+        Array("dog", "cat"),
+        Array("red", "green", "blue")
+      ).toDS
+
+      assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((a1: Array[String], a2: Array[String]) => a1.mkString == a2.mkString))
+    }
+
+    "can compare Dataset[Map[_]]" in {
+      val sourceDS = Seq(
+        Map("apple" -> "banana", "apple1" -> "banana1"),
+        Map("apple" -> "banana", "apple1" -> "banana1")
+      ).toDS
+
+      val expectedDS = Seq(
+        Map("apple" -> "banana", "apple1" -> "banana1"),
+        Map("apple" -> "banana", "apple1" -> "banana1")
+      ).toDS
+
+      assertLargeDatasetEqualityV2(
+        sourceDS,
+        expectedDS,
+        equals = Left((a1: Map[String, String], a2: Map[String, String]) => a1.mkString == a2.mkString)
+      )
+    }
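+
+    // `equals` also accepts a Column predicate via Right(Some(...)). The predicate is applied to the joined
+    // pair's `_1`/`_2` struct columns and should evaluate to true for rows that do NOT match. A hypothetical
+    // sketch (assumes `lower` is imported from org.apache.spark.sql.functions):
+    //   assertLargeDatasetEqualityV2(ds1, ds2, equals = Right(Some(lower(col("_1.name")) =!= lower(col("_2.name")))))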
+
+    "does nothing if the Datasets have the same schemas and content" in {
+      val sourceDS = spark.createDataset[Person](
+        Seq(
+          Person("Alice", 12),
+          Person("Bob", 17)
+        )
+      )
+
+      val expectedDS = spark.createDataset[Person](
+        Seq(
+          Person("Alice", 12),
+          Person("Bob", 17)
+        )
+      )
+
+      assertLargeDatasetEqualityV2(sourceDS, expectedDS)
+    }
+
+    "works with DataFrames that have ArrayType columns" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1, Array("word1", "blah")),
+          (5, Array("hi", "there"))
+        ),
+        List(
+          ("number", IntegerType, true),
+          ("words", ArrayType(StringType, true), true)
+        )
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          (1, Array("word1", "blah")),
+          (5, Array("hi", "there"))
+        ),
+        List(
+          ("number", IntegerType, true),
+          ("words", ArrayType(StringType, true), true)
+        )
+      )
+
+      assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+    }
+
+    "throws an error if the DataFrames have different schemas" in {
+      val nestedSchema = StructType(
+        Seq(
+          StructField(
+            "attributes",
+            StructType(
+              Seq(
+                StructField("PostCode", IntegerType, nullable = true)
+              )
+            ),
+            nullable = true
+          )
+        )
+      )
+
+      val nestedSchema2 = StructType(
+        Seq(
+          StructField(
+            "attributes",
+            StructType(
+              Seq(
+                StructField("PostCode", StringType, nullable = true)
+              )
+            ),
+            nullable = true
+          )
+        )
+      )
+
+      val sourceDF = spark.createDF(
+        List(
+          (1, 2.0, null),
+          (5, 3.0, null)
+        ),
+        List(
+          ("number", IntegerType, true),
+          ("float", DoubleType, true),
+          ("nestedField", nestedSchema, true)
+        )
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          (1, "word", null, 1L),
+          (5, "word", null, 2L)
+        ),
+        List(
+          ("number", IntegerType, true),
+          ("word", StringType, true),
+          ("nestedField", nestedSchema2, true),
+          ("long", LongType, true)
+        )
+      )
+
+      intercept[DatasetSchemaMismatch] {
+        assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+      }
+    }
+
+    "throws an error if the DataFrames content is different" in {
+      val sourceDF = Seq(
+        (1), (5), (7), (1), (1)
+      ).toDF("number")
+
+      val expectedDF = Seq(
+        (10), (5), (3), (7), (1)
+      ).toDF("number")
+
+      intercept[DatasetContentMismatch] {
+        assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+      }
+    }
+
+    "throws an error if the Dataset content is different" in {
+      val sourceDS = spark.createDataset[Person](
+        Seq(
+          Person("Alice", 12),
+          Person("Bob", 17)
+        )
+      )
+
+      val expectedDS = spark.createDataset[Person](
+        Seq(
+          Person("Frank", 10),
+          Person("Lucy", 5)
+        )
+      )
+
+      intercept[DatasetContentMismatch] {
+        assertLargeDatasetEqualityV2(sourceDS, expectedDS)
+      }
+    }
+
+    "succeeds if custom comparator returns true" in {
+      val sourceDS = spark.createDataset[Person](
+        Seq(
+          Person("bob", 1),
+          Person("alice", 5)
+        )
+      )
+      val expectedDS = spark.createDataset[Person](
+        Seq(
+          Person("Bob", 1),
+          Person("Alice", 5)
+        )
+      )
+      assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
+    }
+
+    "fails if custom comparator returns false" in {
+      val sourceDS = spark.createDataset[Person](
+        Seq(
+          Person("bob", 10),
+          Person("alice", 5)
+        )
+      )
+      val expectedDS = spark.createDataset[Person](
+        Seq(
+          Person("Bob", 1),
+          Person("Alice", 5)
+        )
+      )
+
+      intercept[DatasetContentMismatch] {
+        assertLargeDatasetEqualityV2(sourceDS, expectedDS, equals = Left((p1: Person, p2: Person) => Person.caseInsensitivePersonEquals(p1, p2)))
+      }
+    }
+
+  }
+
+  "assertLargeDatasetEqualityV2" - {
+    import spark.implicits._
+
+    "ignores the nullable flag when making DataFrame comparisons" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1),
+          (5)
+        ),
+        List(("number", IntegerType, false))
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          (1),
+          (5)
+        ),
+        List(("number", IntegerType, true))
+      )
+
+      assertLargeDatasetEqualityV2(sourceDF, expectedDF, ignoreNullable = true)
+    }
+
+    "should not ignore nullable if ignoreNullable is false" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1),
+          (5)
+        ),
+        List(("number", IntegerType, false))
+      )
+
+      val expectedDF = spark.createDF(
+        List(
+          (1),
+          (5)
+        ),
+        List(("number", IntegerType, true))
+      )
+
+      intercept[DatasetSchemaMismatch] {
+        assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+      }
+    }
+
+    "throws an error if DataFrames have a different number of rows" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1),
+          (5)
+        ),
+        List(("number", IntegerType, true))
+      )
+      val expectedDF = spark.createDF(
+        List(
+          (1),
+          (5),
+          (10)
+        ),
+        List(("number", IntegerType, true))
+      )
+
+      intercept[DatasetCountMismatch] {
+        assertLargeDatasetEqualityV2(sourceDF, expectedDF)
+      }
+    }
+
+    "can perform DataFrame comparisons with unordered columns" in {
+      val sourceDF = spark.createDF(
+        List(
+          (1, "word"),
+          (5, "word")
+        ),
+        List(
+          ("number", IntegerType, true),
+          ("word", StringType, true)
+        )
+      )
+      val expectedDF = spark.createDF(
+        List(
+          ("word", 1),
+          ("word", 5)
+        ),
+        List(
+          ("word", StringType, true),
+          ("number", IntegerType, true)
+        )
+      )
+      assertLargeDatasetEqualityV2(sourceDF, expectedDF, ignoreColumnOrder = true)
+    }
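+
+    // When row order is not deterministic (e.g. after a shuffle), matching rows by key is more reliable than the
+    // positional default. A hypothetical sketch, treating `number` as the key column:
+    //   assertLargeDatasetEqualityV2(sourceDF, expectedDF, primaryKeys = Seq("number"), checkKeyUniqueness = true)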
("number", IntegerType, true) + ) + ) + assertLargeDatasetEqualityV2(sourceDF, expectedDF, ignoreColumnOrder = true) + } + + "can performed Dataset comparisons with unordered column" in { + val ds1 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS + + val ds2 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS.select("age", "name").as(ds1.encoder) + + assertLargeDatasetEqualityV2(ds2, ds1, ignoreColumnOrder = true) + } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertLargeDatasetEqualityV2(sourceDF, expectedDF) + } + + e.assertColorDiff(Seq("float", "DoubleType", "MISSING"), Seq("word", "StringType", "StructField(long,LongType,true,{})")) + } + + "can performed Dataset comparisons and ignore metadata" in { + val ds1 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS + .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) + .as[Person] + + val ds2 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS + .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) + .as[Person] + + assertLargeDatasetEqualityV2(ds2, ds1) + } + + "can performed Dataset comparisons and compare metadata" in { + val ds1 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS + .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the person").build())) + .as[Person] + + val ds2 = Seq( + Person("juan", 5), + Person("bob", 1), + Person("li", 49), + Person("alice", 5) + ).toDS + .withColumn("name", col("name").as("name", new MetadataBuilder().putString("description", "name of the individual").build())) + .as[Person] + + intercept[DatasetSchemaMismatch] { + assertLargeDatasetEqualityV2(ds2, ds1, ignoreMetadata = false) + } + } + + "can handle when there are unmatched row of Product Type" in { + val sourceDS = spark.createDataset[Person]( + Seq( + Person("Alice", 12), + Person("Bob", 17) + ) + ) + + val expectedDS = spark.createDataset[Person]( + Seq( + Person("Alice", 12), + Person("Bob1", 17) + ) + ) + + intercept[DatasetContentMismatch] { + assertLargeDatasetEqualityV2( + sourceDS, + expectedDS, + ignoreMetadata = false, + equals = Left((p1: Person, p2: Person) => p1.age == p2.age && p1.name == p2.name) + ) + } + + } + + "can handle when there are unmatched rows of Primitive Type" in { + val sourceDS = spark.range(0, 10, 1) + val expectedDS = spark.range(0, 20, 2) + + intercept[DatasetContentMismatch] { + assertLargeDatasetEqualityV2( + sourceDS, + expectedDS, + ignoreMetadata = false, + equals = Left((p1: java.lang.Long, p2: java.lang.Long) => p1 == p2) + ) + } + + } + } +}