Draft assertLargeDatasetEqualityV2 #181
zeotuan wants to merge 9 commits into mrpowers-io:main from zeotuan:nonRddAssert
Commits (9, all by zeotuan):

8e470d7 assertLargeDatasetEqualityV2
3e6b254 use assert API for benchmark
90ec1e1 Add benchmark with join column
37eca14 Add benchmark with multiple join column
ba7c5cb Add benchmark with multiple join column
cd173ef Improve equal comparison option
a528229 Fix assert not work on DF with single column
9c589f4 Typed outer join
6d68eb7 Use builtin joinWith
...hmarks/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerBenchmark.scala (104 additions, 0 deletions)

@@ -0,0 +1,104 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

import java.util.concurrent.TimeUnit
import scala.util.Try

private class DatasetComparerBenchmark extends DatasetComparer {
  def getSparkSession: SparkSession = {
    val session = SparkSession
      .builder()
      .master("local")
      .appName("spark session")
      .getOrCreate()
    session.sparkContext.setLogLevel("ERROR")
    session
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1   = spark.range(0, 1000000, 1, 8)
    val ds3   = ds1

    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithSinglePrimaryKey(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1   = spark.range(0, 1000000, 1, 8)
    val ds3   = ds1

    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id")))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEquality(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1   = spark.range(0, 1000000, 1, 8)
    val ds3   = ds1

    val result = Try(assertLargeDatasetEquality(ds1, ds3))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithCompositePrimaryKey2(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1   = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1)
    val ds3   = ds1
    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2")))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithCompositePrimaryKey3(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1   = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1).withColumn("id3", col("id2") + 1)
    val ds3   = ds1
    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2", "id3")))

    blackHole.consume(result)
    result.isSuccess
  }
}
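As a side note, these JMH benchmarks can also be launched programmatically. The sketch below uses JMH's standard Runner API; the object name BenchmarkRunner is hypothetical and not part of the PR (the project would more typically run benchmarks through an sbt JMH task).

import org.openjdk.jmh.runner.Runner
import org.openjdk.jmh.runner.options.OptionsBuilder

object BenchmarkRunner {
  def main(args: Array[String]): Unit = {
    val opts = new OptionsBuilder()
      .include("DatasetComparerBenchmark") // regex matched against benchmark class names
      .build()
    new Runner(opts).run()
  }
}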
core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala
@@ -1,10 +1,11 @@
 package com.github.mrpowers.spark.fast.tests

+import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
 import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
 import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}

 import scala.reflect.ClassTag

@@ -72,7 +73,7 @@ Expected DataFrame Row Count: '$expectedCount'
     val e = expectedDS.collect().toSeq
     if (!a.approximateSameElements(e, equals)) {
       val arr = ("Actual Content", "Expected Content")
-      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
+      val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, Left(a -> e), truncate)
       throw DatasetContentMismatch(msg)
     }
   }

@@ -152,6 +153,89 @@ Expected DataFrame Row Count: '$expectedCount'
     }
   }

+  /**
+   * Raises an error unless `actualDS` and `expectedDS` are equal. It is recommended to provide `primaryKeys` to ensure accurate and efficient
+   * comparison of rows. If no primary keys are provided, rows are compared by their row number; this requires both datasets to be partitioned
+   * the same way and becomes unreliable when shuffling is involved.
+   * @param primaryKeys
+   *   unique identifier for each row, used to match rows accurately
+   * @param checkKeyUniqueness
+   *   if true, verify that the primary key is actually unique
+   */
+  def assertLargeDatasetEqualityV2[T: ClassTag](
+      actualDS: Dataset[T],
+      expectedDS: Dataset[T],
+      equals: Either[(T, T) => Boolean, Option[Column]] = Right(None),
+      ignoreNullable: Boolean = false,
+      ignoreColumnNames: Boolean = false,
+      ignoreColumnOrder: Boolean = false,
+      ignoreMetadata: Boolean = true,
+      checkKeyUniqueness: Boolean = false,
+      primaryKeys: Seq[String] = Seq.empty,
+      truncate: Int = 500
+  ): Unit = {
+    // first check if the schemas are equal
+    SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
+    val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
+    assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, checkKeyUniqueness, truncate)
+  }
+
+  def assertLargeDatasetContentEqualityV2[T: ClassTag](
+      ds1: Dataset[T],
+      ds2: Dataset[T],
+      equals: Either[(T, T) => Boolean, Option[Column]],
+      primaryKeys: Seq[String],
+      checkKeyUniqueness: Boolean,
+      truncate: Int
+  ): Unit = {
+    try {
+      ds1.cache()
+      ds2.cache()
+
+      val actualCount   = ds1.count
+      val expectedCount = ds2.count
+
+      if (actualCount != expectedCount) {
+        throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
+      }
+
+      if (primaryKeys.nonEmpty && checkKeyUniqueness) {
+        assert(ds1.isKeyUnique(primaryKeys), "Primary key is not unique in actual dataset")
+        assert(ds2.isKeyUnique(primaryKeys), "Primary key is not unique in expected dataset")
+      }
+
+      val joinedDf = ds1
+        .joinPair(ds2, primaryKeys)
+
+      val unequalDS = equals match {
+        case Left(customEquals) =>
+          joinedDf.filter((p: (T, T)) =>
+            // Dataset#joinWith implicitly returns null on either side for rows missing from the outer join, even for primitive types
+            p match {
+              case (null, null) => false
+              case (null, _)    => true
+              case (_, null)    => true
+              case (l, r)       => !customEquals(l, r)
+            }
+          )
+        case Right(equalExprOption) =>
+          joinedDf.filter(equalExprOption.getOrElse(col("_1") =!= col("_2")))
+      }
+
+      if (!unequalDS.isEmpty) {
+        val joined = Right(unequalDS.take(truncate).toSeq)
+        val arr    = ("Actual Content", "Expected Content")
+        val msg    = "Diffs\n" ++ ProductUtil.showProductDiff(arr, joined, truncate)
+        throw DatasetContentMismatch(msg)
+      }
+
+    } finally {
+      ds1.unpersist()
+      ds2.unpersist()
+    }
+  }
+
   def assertApproximateDataFrameEquality(
       actualDF: DataFrame,
       expectedDF: DataFrame,

Review comment on the assertLargeDatasetEqualityV2 signature: How should we approach this? Should I add another version that is closer to what we had before, or just replace the older version with this one?
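To make the new API concrete, here is a minimal usage sketch. The object name, the example data, and the custom comparator are illustrative assumptions, not part of the PR; any class or object mixing in DatasetComparer gets the method.

import org.apache.spark.sql.SparkSession

object UsageSketch extends DatasetComparer {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("usage").getOrCreate()

    val actual   = spark.range(0, 1000, 1, 8) // 8 partitions
    val expected = spark.range(0, 1000, 1, 4) // same rows, different partitioning

    // With primary keys, rows are matched via a full outer join on "id",
    // so the differing partitioning cannot cause false mismatches.
    assertLargeDatasetEqualityV2(actual, expected, primaryKeys = Seq("id"))

    // Left(...) supplies a custom typed comparator. Right(Some(column)) would
    // instead supply a Column predicate that is true for rows that DIFFER,
    // replacing the default col("_1") =!= col("_2").
    assertLargeDatasetEqualityV2(
      actual,
      expected,
      equals = Left((l: java.lang.Long, r: java.lang.Long) => l == r),
      primaryKeys = Seq("id")
    )
  }
}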
core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetUtils.scala (58 additions, 0 deletions)

@@ -0,0 +1,58 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, TypedColumn}

import scala.reflect.ClassTag

private object DatasetUtils {
  implicit class DatasetOps[T: ClassTag](ds: Dataset[T]) {
    def zipWithIndex(indexName: String): DataFrame = ds
      .orderBy()
      .withColumn(indexName, row_number().over(Window.orderBy(monotonically_increasing_id())))
      .select(ds.columns.map(col) :+ col(indexName): _*)

    /**
     * Check whether the primary key is actually unique
     */
    def isKeyUnique(primaryKey: Seq[String]): Boolean =
      ds.select(primaryKey.map(col): _*).distinct.count == ds.count

    def joinPair[P: ClassTag](
        other: Dataset[P],
        primaryKeys: Seq[String]
    ): Dataset[(T, P)] = {
      if (primaryKeys.nonEmpty) {
        ds
          .as("l")
          .joinWith(other.as("r"), primaryKeys.map(k => col(s"l.$k") === col(s"r.$k")).reduce(_ && _), "full_outer")
      } else {
        val indexName = s"index_${java.util.UUID.randomUUID}"
        val joined = ds
          .zipWithIndex(indexName)
          .alias("l")
          .joinWith(other.zipWithIndex(indexName).alias("r"), col(s"l.$indexName") === col(s"r.$indexName"), "full_outer")

        joined
          .select(
            encoderToCol("_1", ds.schema, ds.encoder, Seq(indexName)),
            encoderToCol("_2", other.schema, other.encoder, Seq(indexName))
          )
      }
    }
  }

  private def encoderToCol[P: ClassTag](colName: String, schema: StructType, encoder: Encoder[P], key: Seq[String]): TypedColumn[Any, P] = {
    val columns   = schema.names.map(n => col(s"$colName.$n")) // the name from the encoder is not reliable
    val isRowType = implicitly[ClassTag[P]].runtimeClass == classOf[Row]
    val unTypedColumn =
      if (columns.length == 1 && !isRowType)
        columns.head
      else
        when(key.map(k => col(s"$colName.$k").isNull).reduce(_ && _), lit(null)).otherwise(struct(columns: _*))

    unTypedColumn.as(colName).as[P](encoder)
  }
}
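For context, here is a standalone sketch (with illustrative data, not from the PR) of the Dataset#joinWith behaviour that joinPair relies on: with the "full_outer" join type, the unmatched side of a pair comes back as null, which is why the comparison code has to handle (null, _) and (_, null) explicitly.

import org.apache.spark.sql.SparkSession

object JoinPairSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("joinWith").getOrCreate()
    import spark.implicits._

    val left  = Seq((1, "a"), (2, "b")).toDS()
    val right = Seq((2, "b"), (3, "c")).toDS()

    val pairs = left.joinWith(right, left("_1") === right("_1"), "full_outer")
    pairs.collect().foreach(println)
    // Prints, in some order:
    //   ((1,a),null)   -- only in left
    //   ((2,b),(2,b))  -- matched
    //   (null,(3,c))   -- only in right
  }
}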
Review comment: TODO: add a performance benchmark for typed vs. column-based filters. In theory, filtering with a Column expression should allow better query plan generation; see the sketch below.
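A sketch (with assumed data, not from the PR) of the two filter styles the TODO refers to: a typed lambda is opaque to Catalyst, so rows are deserialized to JVM objects before the predicate runs, while a Column predicate stays in the expression tree and can be optimized and pushed down.

import org.apache.spark.sql.SparkSession

object FilterPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("filters").getOrCreate()
    import spark.implicits._

    val ds = spark.range(0, 1000000)

    // Typed filter: the plan shows a DeserializeToObject step before the predicate.
    ds.filter((x: java.lang.Long) => x % 2 == 0).explain()

    // Column filter: the predicate appears as a plain Filter over the Range scan.
    ds.filter($"id" % 2 === 0).explain()
  }
}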