Draft assertLargeDatasetEqualityV2 #181

wants to merge 9 commits into main
Collaborator Author

TODO: add a performance benchmark comparing the typed filter with the Column-based filter.
In theory, filtering with a Column expression should let Spark generate a better query plan than an opaque lambda.
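A rough sketch of what that benchmark would compare, using a small in-memory Dataset as a stand-in for the joined (actual, expected) pairs; the object and app names are illustrative only, not part of this PR:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object FilterStyleSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("filter-style-sketch").getOrCreate()
    import spark.implicits._

    // stand-in for the joined (actual, expected) pairs produced by joinPair
    val pairs = Seq.tabulate(1000)(i => (i.toLong, i.toLong)).toDS()

    // typed filter: the lambda is opaque to Catalyst, so every pair is deserialized before the predicate runs
    val typedMismatches = pairs.filter((p: (Long, Long)) => p._1 != p._2)

    // column filter: the predicate stays in the logical plan, where Catalyst can optimize it
    val columnMismatches = pairs.filter(col("_1") =!= col("_2"))

    println(s"typed: ${typedMismatches.count()}, column: ${columnMismatches.count()}")
    spark.stop()
  }
}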

@@ -0,0 +1,104 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.openjdk.jmh.annotations._
import org.openjdk.jmh.infra.Blackhole

import java.util.concurrent.TimeUnit
import scala.util.Try

private class DatasetComparerBenchmark extends DatasetComparer {
  def getSparkSession: SparkSession = {
    val session = SparkSession
      .builder()
      .master("local")
      .appName("spark session")
      .getOrCreate()
    session.sparkContext.setLogLevel("ERROR")
    session
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1 = spark.range(0, 1000000, 1, 8)
    val ds3 = ds1

    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithSinglePrimaryKey(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1 = spark.range(0, 1000000, 1, 8)
    val ds3 = ds1

    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id")))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEquality(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1 = spark.range(0, 1000000, 1, 8)
    val ds3 = ds1

    val result = Try(assertLargeDatasetEquality(ds1, ds3))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithCompositePrimaryKey2(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1 = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1)
    val ds3 = ds1
    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2")))

    blackHole.consume(result)
    result.isSuccess
  }

  @Benchmark
  @BenchmarkMode(Array(Mode.SingleShotTime))
  @Fork(value = 2)
  @Warmup(iterations = 10)
  @Measurement(iterations = 10)
  @OutputTimeUnit(TimeUnit.NANOSECONDS)
  def assertLargeDatasetEqualityV2WithCompositePrimaryKey3(blackHole: Blackhole): Boolean = {
    val spark = getSparkSession
    val ds1 = spark.range(0, 1000000, 1, 8).withColumn("id2", col("id") + 1).withColumn("id3", col("id2") + 1)
    val ds3 = ds1
    val result = Try(assertLargeDatasetEqualityV2(ds1, ds3, primaryKeys = Seq("id", "id2", "id3")))

    blackHole.consume(result)
    result.isSuccess
  }
}
11 changes: 7 additions & 4 deletions build.sbt
@@ -38,14 +38,16 @@ lazy val commonSettings = Seq(
else Seq.empty)
},
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.scalatest" %% "scalatest" % "3.2.18" % "test"
"org.scalatest" %% "scalatest" % "3.2.18" % "test"
)
)

lazy val core = (project in file("core"))
.settings(
commonSettings,
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
),
name := "core",
Compile / packageSrc / publishArtifact := true,
Compile / packageDoc / publishArtifact := true
@@ -56,7 +58,8 @@ lazy val benchmarks = (project in file("benchmarks"))
.settings(commonSettings)
.settings(
libraryDependencies ++= Seq(
"org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version!
"org.apache.spark" %% "spark-sql" % sparkVersion % "compile",
"org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" // required for jmh IDEA plugin. Make sure this version matches sbt-jmh version!
),
name := "benchmarks",
publish / skip := true
@@ -121,4 +124,4 @@ updateOptions := updateOptions.value.withLatestSnapshots(false)

import xerial.sbt.Sonatype.sonatypeCentralHost

ThisBuild / sonatypeCredentialHost := sonatypeCentralHost
ThisBuild / sonatypeCredentialHost := sonatypeCentralHost
@@ -1,10 +1,11 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.DatasetUtils.DatasetOps
import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}

import scala.reflect.ClassTag

@@ -72,7 +73,7 @@ Expected DataFrame Row Count: '$expectedCount'
val e = expectedDS.collect().toSeq
if (!a.approximateSameElements(e, equals)) {
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, a, e, truncate)
val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, Left(a -> e), truncate)
throw DatasetContentMismatch(msg)
}
}
@@ -152,6 +153,89 @@ Expected DataFrame Row Count: '$expectedCount'
}
}

  /**
   * Raises an error unless `actualDS` and `expectedDS` are equal. It is recommended to provide `primaryKeys` to ensure an accurate and efficient
   * comparison of rows. If no primary keys are provided, rows are compared by their row number; this requires both datasets to be partitioned in
   * the same way and becomes unreliable when shuffling is involved.
   * @param primaryKeys
   *   unique identifier for each row, used to match rows between the datasets
   * @param checkKeyUniqueness
   *   if true, verifies that the primary key is actually unique
   */
  def assertLargeDatasetEqualityV2[T: ClassTag](
Collaborator Author

How should we approach this? Should I add another variant that is closer to what we had before, or should this version simply replace the older one?
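For reference, here is how I would expect callers to use the new assertion, based on the signature below; `UsageSketch`, `Person`, and the session setup are illustrative and assume the DatasetComparer trait is mixed in, they are not part of this PR:

import org.apache.spark.sql.SparkSession

object UsageSketch extends DatasetComparer {
  // hypothetical record type, only for illustration
  final case class Person(id: Long, name: String)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("usage-sketch").getOrCreate()
    import spark.implicits._

    val actualDS = Seq(Person(1, "alice"), Person(2, "bob")).toDS()
    val expectedDS = Seq(Person(1, "alice"), Person(2, "bob")).toDS()

    // positional comparison, closest to the existing assertLargeDatasetEquality
    assertLargeDatasetEqualityV2(actualDS, expectedDS)

    // key-based comparison (recommended), optionally verifying that the key is unique
    assertLargeDatasetEqualityV2(actualDS, expectedDS, primaryKeys = Seq("id"), checkKeyUniqueness = true)

    // custom typed equality supplied through the Left side of `equals`
    assertLargeDatasetEqualityV2(actualDS, expectedDS, equals = Left((a: Person, b: Person) => a.id == b.id && a.name == b.name))

    spark.stop()
  }
}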

      actualDS: Dataset[T],
      expectedDS: Dataset[T],
      equals: Either[(T, T) => Boolean, Option[Column]] = Right(None),
      ignoreNullable: Boolean = false,
      ignoreColumnNames: Boolean = false,
      ignoreColumnOrder: Boolean = false,
      ignoreMetadata: Boolean = true,
      checkKeyUniqueness: Boolean = false,
      primaryKeys: Seq[String] = Seq.empty,
      truncate: Int = 500
  ): Unit = {
    // first check if the schemas are equal
    SchemaComparer.assertDatasetSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder, ignoreMetadata)
    val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
    assertLargeDatasetContentEqualityV2(actual, expectedDS, equals, primaryKeys, checkKeyUniqueness, truncate)
  }

  def assertLargeDatasetContentEqualityV2[T: ClassTag](
      ds1: Dataset[T],
      ds2: Dataset[T],
      equals: Either[(T, T) => Boolean, Option[Column]],
      primaryKeys: Seq[String],
      checkKeyUniqueness: Boolean,
      truncate: Int
  ): Unit = {
    try {
      ds1.cache()
      ds2.cache()

      val actualCount = ds1.count
      val expectedCount = ds2.count

      if (actualCount != expectedCount) {
        throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
      }

      if (primaryKeys.nonEmpty && checkKeyUniqueness) {
        assert(ds1.isKeyUnique(primaryKeys), "Primary key is not unique in actual dataset")
        assert(ds2.isKeyUnique(primaryKeys), "Primary key is not unique in expected dataset")
      }

      val joinedDf = ds1
        .joinPair(ds2, primaryKeys)

      val unequalDS = equals match {
        case Left(customEquals) =>
          joinedDf.filter((p: (T, T)) =>
            // Dataset.joinWith implicitly returns null on either side for values missing from the outer join, even for primitive types
            p match {
              case (null, null) => false
              case (null, _)    => true
              case (_, null)    => true
              case (l, r)       => !customEquals(l, r)
            }
          )

        case Right(equalExprOption) =>
          joinedDf.filter(equalExprOption.getOrElse(col("_1") =!= col("_2")))
      }

      if (!unequalDS.isEmpty) {
        val joined = Right(unequalDS.take(truncate).toSeq)
        val arr = ("Actual Content", "Expected Content")
        val msg = "Diffs\n" ++ ProductUtil.showProductDiff(arr, joined, truncate)
        throw DatasetContentMismatch(msg)
      }

    } finally {
      ds1.unpersist()
      ds2.unpersist()
    }
  }

def assertApproximateDataFrameEquality(
actualDF: DataFrame,
expectedDF: DataFrame,
@@ -0,0 +1,58 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, TypedColumn}

import scala.reflect.ClassTag

private object DatasetUtils {
  implicit class DatasetOps[T: ClassTag](ds: Dataset[T]) {
    def zipWithIndex(indexName: String): DataFrame = ds
      .orderBy()
      .withColumn(indexName, row_number().over(Window.orderBy(monotonically_increasing_id())))
      .select(ds.columns.map(col) :+ col(indexName): _*)

    /**
     * Check if the primary key is actually unique
     */
    def isKeyUnique(primaryKey: Seq[String]): Boolean =
      ds.select(primaryKey.map(col): _*).distinct.count == ds.count

    def joinPair[P: ClassTag](
        other: Dataset[P],
        primaryKeys: Seq[String]
    ): Dataset[(T, P)] = {
      if (primaryKeys.nonEmpty) {
        ds
          .as("l")
          .joinWith(other.as("r"), primaryKeys.map(k => col(s"l.$k") === col(s"r.$k")).reduce(_ && _), "full_outer")
      } else {
        val indexName = s"index_${java.util.UUID.randomUUID}"
        val joined = ds
          .zipWithIndex(indexName)
          .alias("l")
          .joinWith(other.zipWithIndex(indexName).alias("r"), col(s"l.$indexName") === col(s"r.$indexName"), "full_outer")

        joined
          .select(
            encoderToCol("_1", ds.schema, ds.encoder, Seq(indexName)),
            encoderToCol("_2", other.schema, other.encoder, Seq(indexName))
          )
      }
    }
  }

  private def encoderToCol[P: ClassTag](colName: String, schema: StructType, encoder: Encoder[P], key: Seq[String]): TypedColumn[Any, P] = {
    val columns = schema.names.map(n => col(s"$colName.$n")) // name from encoder is not reliable
    val isRowType = implicitly[ClassTag[P]].runtimeClass == classOf[Row]
    val unTypedColumn =
      if (columns.length == 1 && !isRowType)
        columns.head
      else
        when(key.map(k => col(s"$colName.$k").isNull).reduce(_ && _), lit(null)).otherwise(struct(columns: _*))

    unTypedColumn.as(colName).as[P](encoder)
  }
}
@@ -5,6 +5,7 @@ import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row

import scala.annotation.tailrec
import scala.reflect.ClassTag

object ProductUtil {
@@ -20,8 +21,7 @@ }
}
private[mrpowers] def showProductDiff[T: ClassTag](
header: (String, String),
actual: Seq[T],
expected: Seq[T],
data: Either[(Seq[T], Seq[T]), Seq[(T, T)]],
truncate: Int = 20,
minColWidth: Int = 3
): String = {
@@ -33,7 +33,10 @@

val sb = new StringBuilder

val fullJoin = actual.zipAll(expected, null, null)
val fullJoin = data match {
case Left((actual, expected)) => actual.zipAll(expected, null, null)
case Right(joined) => joined
}

val diff = fullJoin.map { case (actualRow, expectedRow) =>
if (actualRow == expectedRow) {
@@ -14,8 +14,7 @@ object SchemaComparer {
private def betterSchemaMismatchMessage(actualSchema: StructType, expectedSchema: StructType): String = {
showProductDiff(
("Actual Schema", "Expected Schema"),
actualSchema.fields,
expectedSchema.fields,
Left(actualSchema.fields.toSeq -> expectedSchema.fields.toSeq),
truncate = 200
)
}